calibration.py 9.3 KB


  1. import os
  2. from datetime import datetime
  3. import argparse
  4. import math
  5. from psychopy import visual, core, event, logging
  6. from device.data_client import NeuracleDataClient
  7. from device.trigger_box import TriggerNeuracle
  8. from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
  9. from settings.config import settings
  10. from bci_core.online import Controller, model_loader
  11. from settings.config import settings
# Acquisition/device settings from the project config; this script reads the
# 'channel_labels', 'sample_rate', 'host', 'port', 'buffer_length' and
# 'reref' keys from it.
config_info = settings.CONFIG_INFO
  13. def parse_args():
  14. parser = argparse.ArgumentParser(
  15. description='Hand gesture train'
  16. )
  17. parser.add_argument(
  18. '--subj',
  19. dest='subj',
  20. help='Subject name',
  21. default=None,
  22. type=str
  23. )
  24. parser.add_argument(
  25. '--side',
  26. dest='side',
  27. help='train side',
  28. default=None,
  29. type=str
  30. )
  31. parser.add_argument(
  32. '--n-trials',
  33. dest='n_trials',
  34. help='Trial number',
  35. type=int,
  36. )
  37. parser.add_argument(
  38. '--hand-feedback',
  39. dest='hand_feedback',
  40. action='store_true',
  41. )
  42. parser.add_argument(
  43. '--hand-port',
  44. dest='hand_port',
  45. help='Peripheral serial port',
  46. type=str
  47. )
  48. parser.add_argument(
  49. '--trigger-port',
  50. dest='trigger_port',
  51. help='Triggerbox serial port',
  52. type=str
  53. )
  54. parser.add_argument(
  55. '--finger-model',
  56. '-fm',
  57. dest='finger_model',
  58. help='Gesture to train',
  59. type=str
  60. )
  61. parser.add_argument(
  62. '--virtual-feedback-rate',
  63. '-vfr',
  64. dest='virtual_feedback_rate',
  65. help='Virtual feedback rate',
  66. type=float
  67. )
  68. parser.add_argument(
  69. '--model-filename',
  70. dest='model_filename',
  71. help='Model file',
  72. default=None,
  73. type=str
  74. )
  75. return parser.parse_args()
  76. args = parse_args()
  77. # connect neo
  78. receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']),
  79. samplerate=config_info['sample_rate'],
  80. host=config_info['host'],
  81. port=config_info['port'],
  82. buffer_len=config_info['buffer_length'])
  83. # connect to trigger box
  84. trigger = TriggerNeuracle(port=args.trigger_port)
  85. if args.hand_feedback:
  86. # connect to mechanical hand
  87. hand_device = FuboPneumaticFingerClient({'port': args.hand_port})
  88. # build bci controller
  89. if args.model_filename is not None:
  90. model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
  91. control_model = model_loader(model_path)
  92. else:
  93. control_model = None
  94. controller = Controller(args.virtual_feedback_rate,
  95. control_model,
  96. reref_method=config_info['reref'])
  97. time_prepare = 1.5 # in seconds
  98. time_move = 7.
  99. time_blank = 1.
  100. time_rest = 7.
  101. time_feedback = 2.
  102. time_update = 0.2
  103. cnt_threshold_table = {
  104. 'easy': 0.6 * (time_move + time_rest) / time_update,
  105. 'mid': 0.7 * (time_move + time_rest) / time_update,
  106. 'hard': 0.8 * (time_move + time_rest) / time_update
  107. }
  108. # setup logger
  109. logging.console.setLevel(logging.INFO)
  110. datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
  111. log_file_path = os.path.join(settings.DATA_PATH, args.subj, f"calibration_{datetime_now}.log")
  112. if not os.path.isdir(os.path.join(settings.DATA_PATH, args.subj)):
  113. os.mkdir(os.path.join(settings.DATA_PATH, args.subj))
  114. logger = logging.LogFile(log_file_path, level=logging.INFO, filemode='w')
  115. # initialize all components
  116. # setup window
  117. win = visual.Window(
  118. size=[1920, 1080], fullscr=True, screen=0,
  119. winType='pyglet', allowStencil=False,
  120. monitor='testMonitor', color=[1,1,1], colorSpace='rgb',
  121. backgroundImage='', backgroundFit='none',
  122. blendMode='avg', useFBO=True,
  123. units='height'
  124. )
  125. train_position = visual.TextStim(win=win, name='train_position',
  126. text='训练部位:右手' if args.side == 'right' else '训练部位:左手',
  127. font='Open Sans',
  128. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  129. color='black', colorSpace='rgb', opacity=None,
  130. languageStyle='LTR',
  131. depth=0.0);
  132. instruction = visual.TextStim(win=win, name='instruction',
  133. text='准备进行一般抓握训练,\n按任意键继续',
  134. font='Open Sans',
  135. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  136. color='black', colorSpace='rgb', opacity=None,
  137. languageStyle='LTR',
  138. depth=-1.0)
  139. prepare = visual.TextStim(win=win, name='text',
  140. text='请准备',
  141. font='Open Sans',
  142. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  143. color='black', colorSpace='rgb', opacity=None,
  144. languageStyle='LTR',
  145. depth=-1.0)
  146. img_move = visual.ImageStim(
  147. win=win,
  148. name='img_move',
  149. image=f'{settings.IMAGE_PATH}/hand_move_{args.side}.png', mask=None, anchor='center',
  150. ori=0.0, pos=(0, 0), size=None,
  151. color=[1,1,1], colorSpace='rgb', opacity=None,
  152. flipHoriz=False, flipVert=False,
  153. texRes=128.0, interpolate=True, depth=0.0)
  154. img_rest = visual.ImageStim(
  155. win=win,
  156. name='img_rest',
  157. image=f'{settings.IMAGE_PATH}/rest.png', mask=None, anchor='center',
  158. ori=0.0, pos=(0, 0), size=None,
  159. color=[1,1,1], colorSpace='rgb', opacity=None,
  160. flipHoriz=False, flipVert=False,
  161. texRes=128.0, interpolate=True, depth=0.0)
  162. # progress bar
  163. feedback_bar = visual.Progress(
  164. win, name='feedback_bar',
  165. progress=0,
  166. pos=(0.8, -0.25), size=(0.5, 0.1), anchor='bottom-left', units='height',
  167. barColor='black', backColor=None, borderColor='black', colorSpace='rgb',
  168. lineWidth=4.0, opacity=1.0, ori=270.0,
  169. depth=0
  170. )
  171. feedback = visual.TextStim(win=win, name='feedback',
  172. text=None,
  173. font='Open Sans',
  174. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  175. color='black', colorSpace='rgb', opacity=None,
  176. languageStyle='LTR',
  177. depth=0.0)
  178. mi_end = visual.TextStim(win=win, name='mi_end',
  179. text='结束实验',
  180. font='Open Sans',
  181. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  182. color='black', colorSpace='rgb', opacity=None,
  183. languageStyle='LTR',
  184. depth=0.0)
  185. clock = core.Clock()
  186. def exit():
  187. receiver.close()
  188. win.close()
  189. core.quit()
  190. def check_exit():
  191. if 'escape' in event.getKeys():
  192. exit()
  193. def show_and_judge(img, trial_time, finger_model, hand_feedback=True):
  194. def update_elements():
  195. feedback_bar.draw()
  196. img.draw()
  197. win.flip()
  198. correct_cnt = 0
  199. hand_device_started = False
  200. true_label = settings.FINGERMODEL_IDS[finger_model]
  201. clock.reset()
  202. # send trigger
  203. win.callOnFlip(trigger.send_trigger, true_label)
  204. # reset progress bar
  205. feedback_bar.progress = 0
  206. # draw and flip
  207. update_elements()
  208. # wait trial_time seconds
  209. while True:
  210. clock_time = clock.getTime()
  211. if clock_time >= trial_time:
  212. break
  213. check_exit()
  214. # for each time step
  215. if abs(clock_time % time_update) < 1e-4:
  216. # get data
  217. data_from_buffer = receiver.get_trial_data(clear=False)
  218. decision = controller.step_decision(data_from_buffer, true_label)
  219. if decision == true_label:
  220. correct_cnt += 1
  221. # update bar progress
  222. bar_progress = math.sqrt(correct_cnt / (trial_time / time_update))
  223. feedback_bar.progress = bar_progress
  224. # send hand feedback
  225. if hand_feedback and (not hand_device_started):
  226. if finger_model == 'rest':
  227. hand_device.start('extend')
  228. else:
  229. hand_device.start(finger_model)
  230. hand_device_started = True
  231. # draw and flip
  232. update_elements()
  233. return correct_cnt
  234. def img_trial(time, img=None):
  235. clock.reset()
  236. if img is not None:
  237. img.draw()
  238. win.flip()
  239. while clock.getTime() < time:
  240. check_exit()
  241. def mi_trial(trial_ind):
  242. # mi_prepare
  243. img_trial(time_prepare, prepare)
  244. correct_cnt = 0
  245. # blank
  246. img_trial(time_blank)
  247. # mi_move
  248. correct_cnt += show_and_judge(img_move, time_move, args.finger_model, args.hand_feedback)
  249. # blank
  250. img_trial(time_blank)
  251. # mi_rest
  252. correct_cnt += show_and_judge(img_rest, time_rest, 'rest', args.hand_feedback)
  253. # logging data
  254. current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
  255. logging.exp(f'Trial {trial_ind + 1} correct count: {correct_cnt} - current time: {current_time}')
  256. # mi_feedback
  257. if correct_cnt >= cnt_threshold_table['hard']:
  258. feedback.text = '完美!'
  259. grad = 10
  260. elif correct_cnt >= cnt_threshold_table['mid']:
  261. feedback.text = '优秀!'
  262. grad = 9
  263. else:
  264. feedback.text = '好'
  265. grad = 7
  266. img_trial(time_feedback, feedback)
  267. return grad
  268. def run_exp():
  269. # prepare
  270. clock.reset()
  271. prepare.draw()
  272. win.flip()
  273. event.waitKeys()
  274. # run
  275. grad = 0
  276. for i in range(args.n_trials):
  277. grad += mi_trial(i)
  278. # end exp
  279. clock.reset()
  280. grad = grad / args.n_trials * 10
  281. mi_end.text = f"实验结束,\n得分:{int(grad)}"
  282. mi_end.draw()
  283. win.flip()
  284. logging.exp(f'Exp grade: {int(grad)}')
  285. while True:
  286. check_exit()
# Script entry point: runs the whole session. The process ends from inside
# run_exp() (via exit()) once Escape is pressed.
run_exp()