import os
from datetime import datetime
import argparse
import math

from psychopy import visual, core, event, logging

from device.data_client import NeuracleDataClient
from device.trigger_box import TriggerNeuracle
from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
from bci_core.online import Controller, model_loader
from settings.config import settings

config_info = settings.CONFIG_INFO


def parse_args():
    parser = argparse.ArgumentParser(
        description='Hand gesture training'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--side',
        dest='side',
        help='Training side (left or right)',
        default=None,
        type=str
    )
    parser.add_argument(
        '--n-trials',
        dest='n_trials',
        help='Number of trials',
        type=int,
    )
    parser.add_argument(
        '--hand-feedback',
        dest='hand_feedback',
        action='store_true',
        help='Enable pneumatic hand feedback'
    )
    parser.add_argument(
        '--hand-port',
        dest='hand_port',
        help='Peripheral serial port',
        type=str
    )
    parser.add_argument(
        '--trigger-port',
        dest='trigger_port',
        help='Triggerbox serial port',
        type=str
    )
    parser.add_argument(
        '--finger-model',
        '-fm',
        dest='finger_model',
        help='Gesture to train',
        type=str
    )
    parser.add_argument(
        '--virtual-feedback-rate',
        '-vfr',
        dest='virtual_feedback_rate',
        help='Virtual feedback rate',
        type=float
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model file name',
        default=None,
        type=str
    )
    return parser.parse_args()


args = parse_args()

# connect neo
receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']),
                              samplerate=config_info['sample_rate'],
                              host=config_info['host'],
                              port=config_info['port'],
                              buffer_len=config_info['buffer_length'])

# connect to trigger box
trigger = TriggerNeuracle(port=args.trigger_port)

if args.hand_feedback:
    # connect to mechanical hand
    hand_device = FuboPneumaticFingerClient({'port': args.hand_port})

# build bci controller
if args.model_filename is not None:
    model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
    control_model = model_loader(model_path)
else:
    control_model = None
controller = Controller(args.virtual_feedback_rate, control_model,
                        reref_method=config_info['reref'])

# trial timing (in seconds)
time_prepare = 1.5
time_move = 7.
time_blank = 1.
time_rest = 7.
time_feedback = 2.
time_update = 0.2

# correct-decision counts required for each feedback grade
cnt_threshold_table = {
    'easy': 0.6 * (time_move + time_rest) / time_update,
    'mid': 0.7 * (time_move + time_rest) / time_update,
    'hard': 0.8 * (time_move + time_rest) / time_update
}

# setup logger
logging.console.setLevel(logging.INFO)
datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_file_path = os.path.join(settings.DATA_PATH, args.subj, f"calibration_{datetime_now}.log")
if not os.path.isdir(os.path.join(settings.DATA_PATH, args.subj)):
    os.mkdir(os.path.join(settings.DATA_PATH, args.subj))
logger = logging.LogFile(log_file_path, level=logging.INFO, filemode='w')

# initialize all components
# setup window
win = visual.Window(
    size=[1920, 1080], fullscr=True, screen=0,
    winType='pyglet', allowStencil=False,
    monitor='testMonitor', color=[1, 1, 1], colorSpace='rgb',
    backgroundImage='', backgroundFit='none',
    blendMode='avg', useFBO=True,
    units='height'
)
# text stimuli
train_position = visual.TextStim(win=win, name='train_position',
                                 # "Training side: right hand" / "Training side: left hand"
                                 text='训练部位:右手' if args.side == 'right' else '训练部位:左手',
                                 font='Open Sans',
                                 pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
                                 color='black', colorSpace='rgb', opacity=None,
                                 languageStyle='LTR',
                                 depth=0.0)
instruction = visual.TextStim(win=win, name='instruction',
                              # "Get ready for general grasp training,\npress any key to continue"
                              text='准备进行一般抓握训练,\n按任意键继续',
                              font='Open Sans',
                              pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
                              color='black', colorSpace='rgb', opacity=None,
                              languageStyle='LTR',
                              depth=-1.0)
prepare = visual.TextStim(win=win, name='text',
                          # "Please get ready"
                          text='请准备',
                          font='Open Sans',
                          pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
                          color='black', colorSpace='rgb', opacity=None,
                          languageStyle='LTR',
                          depth=-1.0)
# cue images for the move and rest phases
img_move = visual.ImageStim(
    win=win,
    name='img_move',
    image=f'{settings.IMAGE_PATH}/hand_move_{args.side}.png', mask=None, anchor='center',
    ori=0.0, pos=(0, 0), size=None,
    color=[1, 1, 1], colorSpace='rgb', opacity=None,
    flipHoriz=False, flipVert=False,
    texRes=128.0, interpolate=True, depth=0.0)
img_rest = visual.ImageStim(
    win=win,
    name='img_rest',
    image=f'{settings.IMAGE_PATH}/rest.png', mask=None, anchor='center',
    ori=0.0, pos=(0, 0), size=None,
    color=[1, 1, 1], colorSpace='rgb', opacity=None,
    flipHoriz=False, flipVert=False,
    texRes=128.0, interpolate=True, depth=0.0)
# progress bar
feedback_bar = visual.Progress(
    win, name='feedback_bar',
    progress=0,
    pos=(0.8, -0.25), size=(0.5, 0.1), anchor='bottom-left', units='height',
    barColor='black', backColor=None, borderColor='black', colorSpace='rgb',
    lineWidth=4.0, opacity=1.0, ori=270.0,
    depth=0
)
feedback = visual.TextStim(win=win, name='feedback',
                           text=None,
                           font='Open Sans',
                           pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
                           color='black', colorSpace='rgb', opacity=None,
                           languageStyle='LTR',
                           depth=0.0)
mi_end = visual.TextStim(win=win, name='mi_end',
                         # "End of experiment"
                         text='结束实验',
                         font='Open Sans',
                         pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
                         color='black', colorSpace='rgb', opacity=None,
                         languageStyle='LTR',
                         depth=0.0)

clock = core.Clock()


def exit():
    receiver.close()
    win.close()
    core.quit()


def check_exit():
    if 'escape' in event.getKeys():
        exit()


def show_and_judge(img, trial_time, finger_model, hand_feedback=True):
    def update_elements():
        feedback_bar.draw()
        img.draw()
        win.flip()

    correct_cnt = 0
    hand_device_started = False
    true_label = settings.FINGERMODEL_IDS[finger_model]
    clock.reset()
    # send trigger on the next flip
    win.callOnFlip(trigger.send_trigger, true_label)
    # reset progress bar
    feedback_bar.progress = 0
    # draw and flip
    update_elements()
    # wait trial_time seconds
    while True:
        clock_time = clock.getTime()
        if clock_time >= trial_time:
            break
        check_exit()
        # run one decision step roughly every time_update seconds
        if abs(clock_time % time_update) < 1e-4:
            # get data from the amplifier buffer and classify it
            data_from_buffer = receiver.get_trial_data(clear=False)
            decision = controller.step_decision(data_from_buffer, true_label)
            if decision == true_label:
                correct_cnt += 1
            # update bar progress
            bar_progress = math.sqrt(correct_cnt / (trial_time / time_update))
            feedback_bar.progress = bar_progress
            # send hand feedback (once per trial)
            if hand_feedback and (not hand_device_started):
                if finger_model == 'rest':
                    hand_device.start('extend')
                else:
                    hand_device.start(finger_model)
                hand_device_started = True
            # draw and flip
            update_elements()
    return correct_cnt


def img_trial(time, img=None):
    clock.reset()
    if img is not None:
        img.draw()
    win.flip()
    while clock.getTime() < time:
        check_exit()


def mi_trial(trial_ind):
    # mi_prepare
    img_trial(time_prepare, prepare)
    correct_cnt = 0
    # blank
    img_trial(time_blank)
    # mi_move
    correct_cnt += show_and_judge(img_move, time_move, args.finger_model, args.hand_feedback)
    # blank
    img_trial(time_blank)
    # mi_rest
    correct_cnt += show_and_judge(img_rest, time_rest, 'rest', args.hand_feedback)
    # logging data
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    logging.exp(f'Trial {trial_ind + 1} correct count: {correct_cnt} - current time: {current_time}')
    # mi_feedback
    if correct_cnt >= cnt_threshold_table['hard']:
        feedback.text = '完美!'  # "Perfect!"
        grad = 10
    elif correct_cnt >= cnt_threshold_table['mid']:
        feedback.text = '优秀!'  # "Excellent!"
        grad = 9
    else:
        feedback.text = '好'  # "Good"
        grad = 7
    img_trial(time_feedback, feedback)
    return grad


def run_exp():
    # prepare: show the ready screen and wait for a key press
    clock.reset()
    prepare.draw()
    win.flip()
    event.waitKeys()
    # run
    grad = 0
    for i in range(args.n_trials):
        grad += mi_trial(i)
    # end exp
    clock.reset()
    grad = grad / args.n_trials * 10
    mi_end.text = f"实验结束,\n得分:{int(grad)}"  # "Experiment finished,\nscore: ..."
    mi_end.draw()
    win.flip()
    logging.exp(f'Exp grade: {int(grad)}')
    while True:
        check_exit()


run_exp()
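
# Example invocation (a sketch only: the script name, serial ports, subject name,
# gesture and model file below are placeholders; --finger-model must be a key of
# settings.FINGERMODEL_IDS, and --hand-port is only needed with --hand-feedback):
#   python train_hand_gesture.py --subj sub01 --side right --n-trials 10 \
#       --trigger-port COM3 --hand-feedback --hand-port COM4 \
#       -fm <gesture> -vfr 1.0 --model-filename <model_file>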