123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343 |
- import os
- from datetime import datetime
- import argparse
- import math
- from psychopy import visual, core, event, logging
- from device.data_client import NeuracleDataClient
- from device.trigger_box import TriggerNeuracle
- from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
- from settings.config import settings
- from bci_core.online import Controller, model_loader
- from settings.config import settings
# Acquisition / connection settings. Below, this script reads the channel
# labels, sample rate, host, port, buffer length and re-reference method
# from it.
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse the command line for the hand-gesture training session.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, so the existing ``parse_args()`` call behaves as
            before.

    Returns:
        ``argparse.Namespace`` with ``subj``, ``side``, ``n_trials``,
        ``hand_feedback``, ``hand_port``, ``trigger_port``, ``finger_model``,
        ``virtual_feedback_rate`` and ``model_filename``.

    Exits through ``parser.error`` (``SystemExit``) when a required option is
    missing or ``--n-trials`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description='Hand gesture train')
    # subj / side / n-trials / finger-model are used unconditionally by the
    # script (data and log directory, cue image path, trial loop,
    # FINGERMODEL_IDS lookup). With a default of None the script crashed
    # part-way through hardware setup instead of printing a usage error.
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--side',
        dest='side',
        help='train side',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--n-trials',
        dest='n_trials',
        help='Trial number',
        required=True,
        type=int,
    )
    parser.add_argument(
        '--hand-feedback',
        dest='hand_feedback',
        help='Drive the pneumatic hand as feedback',
        action='store_true',
    )
    parser.add_argument(
        '--hand-port',
        dest='hand_port',
        help='Peripheral serial port',
        type=str,
    )
    parser.add_argument(
        '--trigger-port',
        dest='trigger_port',
        help='Triggerbox serial port',
        type=str,
    )
    parser.add_argument(
        '--finger-model',
        '-fm',
        dest='finger_model',
        help='Gesture to train',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--virtual-feedback-rate',
        '-vfr',
        dest='virtual_feedback_rate',
        help='Virtual feedback rate',
        type=float,
    )
    parser.add_argument(
        '--model-filename',
        dest='model_filename',
        help='Model file',
        default=None,
        type=str,
    )

    args = parser.parse_args(argv)
    # The session score divides the summed trial grades by n_trials.
    if args.n_trials <= 0:
        parser.error('--n-trials must be a positive integer')
    return args
# ---- command line and hardware connections ----
args = parse_args()

# Streaming data client (presumably EEG); every connection detail comes from
# config_info.
receiver = NeuracleDataClient(
    n_channel=len(config_info['channel_labels']),
    samplerate=config_info['sample_rate'],
    host=config_info['host'],
    port=config_info['port'],
    buffer_len=config_info['buffer_length'],
)

# Trigger box; show_and_judge sends the trial label through it.
trigger = TriggerNeuracle(port=args.trigger_port)

# The pneumatic hand is only connected when --hand-feedback is given;
# otherwise `hand_device` is never defined and must not be used.
if args.hand_feedback:
    hand_device = FuboPneumaticFingerClient({'port': args.hand_port})

# Optional pre-trained decoder under MODEL_PATH/<subj>/. Without
# --model-filename the controller is built with no model (None).
control_model = None
if args.model_filename is not None:
    model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
    control_model = model_loader(model_path)

controller = Controller(
    args.virtual_feedback_rate,
    control_model,
    reref_method=config_info['reref'],
)
# Phase durations, in seconds.
time_prepare = 1.5
time_move = 7.
time_blank = 1.
time_rest = 7.
time_feedback = 2.
# Spacing between two online decoder decisions, in seconds.
time_update = 0.2

# Correct-decision counts (summed over one move + rest pair) needed for each
# difficulty level: a fraction of the decision slots spanned by the pair.
cnt_threshold_table = {
    level: ratio * (time_move + time_rest) / time_update
    for level, ratio in (('easy', 0.6), ('mid', 0.7), ('hard', 0.8))
}
# ---- logging ----
# PsychoPy console output at INFO. The experiment log goes to a timestamped
# file in the subject's data directory.
logging.console.setLevel(logging.INFO)
datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
subj_data_dir = os.path.join(settings.DATA_PATH, args.subj)
# makedirs also creates a missing DATA_PATH and is a no-op when the folder
# already exists. The former isdir + os.mkdir pair raised FileNotFoundError
# when DATA_PATH itself was absent.
os.makedirs(subj_data_dir, exist_ok=True)
log_file_path = os.path.join(subj_data_dir, f"calibration_{datetime_now}.log")
logger = logging.LogFile(log_file_path, level=logging.INFO, filemode='w')
# ---- display and stimuli ----
win = visual.Window(
    size=[1920, 1080],
    fullscr=True,
    screen=0,
    winType='pyglet',
    allowStencil=False,
    monitor='testMonitor',
    color=[1, 1, 1],
    colorSpace='rgb',
    backgroundImage='',
    backgroundFit='none',
    blendMode='avg',
    useFBO=True,
    units='height',
)


def _make_text(name, text, depth):
    """Build a centred black TextStim with the layout shared by all text screens."""
    return visual.TextStim(
        win=win, name=name, text=text,
        font='Open Sans',
        pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
        color='black', colorSpace='rgb', opacity=None,
        languageStyle='LTR',
        depth=depth,
    )


def _make_image(name, image):
    """Build a centred ImageStim with the options shared by the cue images."""
    return visual.ImageStim(
        win=win, name=name, image=image,
        mask=None, anchor='center',
        ori=0.0, pos=(0, 0), size=None,
        color=[1, 1, 1], colorSpace='rgb', opacity=None,
        flipHoriz=False, flipVert=False,
        texRes=128.0, interpolate=True, depth=0.0,
    )


# "Training site: right hand" for side == 'right', otherwise "left hand".
# NOTE(review): train_position is never drawn in this script; confirm whether
# it was meant to appear on the start screen.
train_position = _make_text(
    'train_position',
    '训练部位:右手' if args.side == 'right' else '训练部位:左手',
    0.0,
)
# "About to start general grasp training, press any key to continue".
instruction = _make_text('instruction', '准备进行一般抓握训练,\n按任意键继续', -1.0)
# "Get ready". Shown at the start of every trial.
prepare = _make_text('text', '请准备', -1.0)

# Motor-imagery cue for the trained side, and the rest cue.
img_move = _make_image('img_move', f'{settings.IMAGE_PATH}/hand_move_{args.side}.png')
img_rest = _make_image('img_rest', f'{settings.IMAGE_PATH}/rest.png')

# Progress bar rotated by 270 degrees and placed right of centre. Its value is
# set by show_and_judge.
feedback_bar = visual.Progress(
    win,
    name='feedback_bar',
    progress=0,
    pos=(0.8, -0.25),
    size=(0.5, 0.1),
    anchor='bottom-left',
    units='height',
    barColor='black',
    backColor=None,
    borderColor='black',
    colorSpace='rgb',
    lineWidth=4.0,
    opacity=1.0,
    ori=270.0,
    depth=0,
)

# Per-trial verdict; its text is filled in by mi_trial.
feedback = _make_text('feedback', None, 0.0)
# "End of experiment". Replaced with the final score by run_exp.
mi_end = _make_text('mi_end', '结束实验', 0.0)

clock = core.Clock()
def exit():
    """Shut the experiment down: close the data client, close the window, quit.

    The name shadows the builtin ``exit`` inside this script. It is kept
    because ``check_exit`` calls it.

    A failure while closing the data client is logged rather than propagated,
    and ``core.quit()`` runs even if closing the window fails. Either way the
    full-screen window cannot be left open with the process still alive.
    """
    try:
        receiver.close()
    except Exception as err:
        # Best effort: a broken data connection must not block the shutdown.
        logging.error(f'Failed to close data client: {err!r}')
    try:
        win.close()
    finally:
        core.quit()
def check_exit():
    """Poll the keyboard once and shut the experiment down if Escape was pressed."""
    pressed = event.getKeys()
    if 'escape' not in pressed:
        return
    exit()
-
def show_and_judge(img, trial_time, finger_model, hand_feedback=True):
    """Show ``img`` for ``trial_time`` seconds while scoring online decisions.

    Every ``time_update`` seconds the current data buffer is passed to
    ``controller.step_decision``. Each decision equal to the label of
    ``finger_model`` increments the count and advances ``feedback_bar``. The
    first such decision starts ``hand_device`` when ``hand_feedback`` is true.

    Args:
        img: stimulus drawn for the whole period.
        trial_time: period length in seconds.
        finger_model: key into ``settings.FINGERMODEL_IDS``. ``'rest'`` drives
            the hand with ``'extend'``; any other value is passed to the hand as is.
        hand_feedback: whether to drive ``hand_device``. That name only exists
            when the script was started with ``--hand-feedback``.

    Returns:
        Number of decision steps whose output equalled the true label.
    """
    def update_elements():
        # Redraw the bar and the cue, then present the frame.
        feedback_bar.draw()
        img.draw()
        win.flip()

    correct_cnt = 0
    hand_device_started = False
    true_label = settings.FINGERMODEL_IDS[finger_model]
    n_slots = trial_time / time_update  # nominal number of decision slots
    clock.reset()
    # The trigger goes out on the flip that first shows the cue.
    win.callOnFlip(trigger.send_trigger, true_label)
    feedback_bar.progress = 0
    update_elements()

    # Decisions run on the schedule time_update, 2 * time_update, ... The old
    # test `abs(clock_time % time_update) < 1e-4` only fired when a loop
    # iteration happened to land in a 0.1 ms window. Each step blocks on
    # win.flip(), so slots were silently skipped and the score depended on loop
    # timing rather than on the decoder.
    next_slot = 1
    while True:
        clock_time = clock.getTime()
        if clock_time >= trial_time:
            break
        check_exit()
        if clock_time < next_slot * time_update:
            continue
        # Always move forward at least one slot. Slots missed while a step ran
        # late are skipped rather than fired as a catch-up burst.
        next_slot = max(next_slot + 1, math.floor(clock_time / time_update) + 1)

        data_from_buffer = receiver.get_trial_data(clear=False)
        decision = controller.step_decision(data_from_buffer, true_label)
        if decision == true_label:
            correct_cnt += 1
            # sqrt lets the bar rise faster for the first correct decisions.
            feedback_bar.progress = math.sqrt(correct_cnt / n_slots)
            if hand_feedback and not hand_device_started:
                hand_device.start('extend' if finger_model == 'rest' else finger_model)
                hand_device_started = True
        update_elements()

    return correct_cnt
def img_trial(time, img=None):
    """Present ``img`` (or an empty frame when ``img`` is None) for ``time`` seconds.

    Escape is polled throughout the wait, so a phase can be aborted.
    """
    clock.reset()
    if img is not None:
        img.draw()
    win.flip()
    elapsed = clock.getTime
    while True:
        if elapsed() >= time:
            break
        check_exit()
def mi_trial(trial_ind):
    """Run one trial and return its grade.

    Sequence: "get ready" cue, blank, move period, blank, rest period, log
    line, graded feedback screen. The grade is 10 if the summed correct count
    reaches the 'hard' threshold, 9 if it reaches 'mid', and 7 otherwise.

    NOTE(review): cnt_threshold_table['easy'] is never consulted, so every
    count below 'mid' gets 7. Confirm this is intended.

    Args:
        trial_ind: zero-based trial index (logged one-based).
    """
    img_trial(time_prepare, prepare)
    img_trial(time_blank)
    score = show_and_judge(img_move, time_move, args.finger_model, args.hand_feedback)
    img_trial(time_blank)
    score += show_and_judge(img_rest, time_rest, 'rest', args.hand_feedback)

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    logging.exp(f'Trial {trial_ind + 1} correct count: {score} - current time: {stamp}')

    # Highest level first; fall back to '好' ("good") / 7.
    text, grad = '好', 7
    for level, level_text, level_grad in (('hard', '完美!', 10), ('mid', '优秀!', 9)):
        if score >= cnt_threshold_table[level]:
            text, grad = level_text, level_grad
            break
    feedback.text = text
    img_trial(time_feedback, feedback)
    return grad
def run_exp():
    """Run the whole session: start screen, ``args.n_trials`` trials, final score.

    The final score is the mean per-trial grade multiplied by 10. It is shown
    on screen and written to the experiment log. The function then idles,
    polling for Escape, until the operator quits.
    """
    # Start screen. The instruction stim is the one whose text asks the
    # participant to press any key; previously the "get ready" stim was shown
    # here and the instruction was never displayed.
    clock.reset()
    instruction.draw()
    win.flip()
    keys = event.waitKeys()
    # Escape quits here too, consistent with check_exit() during the trials.
    if keys and 'escape' in keys:
        exit()

    grad = 0
    for i in range(args.n_trials):
        grad += mi_trial(i)

    # End screen.
    clock.reset()
    # Guard against zero trials, which used to raise ZeroDivisionError.
    grad = grad / args.n_trials * 10 if args.n_trials else 0
    mi_end.text = f"实验结束,\n得分:{int(grad)}"
    mi_end.draw()
    win.flip()
    logging.exp(f'Exp grade: {int(grad)}')
    while True:
        check_exit()


run_exp()
|