free_grasp.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import os
  2. import argparse
  3. from datetime import datetime
  4. import random
  5. from psychopy import visual, core, event, logging
  6. from device.data_client import NeuracleDataClient
  7. from device.trigger_box import TriggerNeuracle
  8. from device.fubo_pneumatic_finger import FingerController
  9. from device.arduino import send_toggle_command,connect_to_arduino
  10. from settings.config import settings
  11. from bci_core.online import Controller, model_loader
  12. config_info = settings.CONFIG_INFO
  13. fingermodel_ids_inverse = settings.FINGERMODEL_IDS_INVERSE
  14. # get train params
  15. def parse_args():
  16. parser = argparse.ArgumentParser(
  17. description='Hand gesture train'
  18. )
  19. parser.add_argument(
  20. '--subj',
  21. dest='subj',
  22. help='Subject name',
  23. default=None,
  24. type=str
  25. )
  26. parser.add_argument(
  27. '--hand-port',
  28. dest='hand_port',
  29. help='Peripheral serial port',
  30. type=str
  31. )
  32. parser.add_argument(
  33. '--trigger-port',
  34. dest='trigger_port',
  35. help='Triggerbox serial port',
  36. type=str
  37. )
  38. parser.add_argument(
  39. '--arduino-port',
  40. dest='arduino_port',
  41. help='Arduino serial port',
  42. type=str
  43. )
  44. parser.add_argument(
  45. '--state-change-threshold',
  46. '-scth',
  47. dest='state_change_threshold',
  48. help='Threshold for HMM state change',
  49. type=float
  50. )
  51. parser.add_argument(
  52. '--state-trans-prob',
  53. '-stp',
  54. dest='state_trans_prob',
  55. help='Transition probability for HMM state change',
  56. default=0.8,
  57. type=float
  58. )
  59. parser.add_argument(
  60. '--momentum',
  61. help='Probability update momentum',
  62. default=0.5,
  63. type=float
  64. )
  65. parser.add_argument(
  66. '--model-filename',
  67. dest='model_filename',
  68. help='Model file',
  69. default=None,
  70. type=str
  71. )
  72. parser.add_argument(
  73. '--debug',
  74. action='store_true',
  75. help='Store true to debug progress'
  76. )
  77. return parser.parse_args()
  78. args = parse_args()
  79. # setup logger
  80. logging.console.setLevel(logging.INFO)
  81. datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
  82. log_file_path = os.path.join(settings.DATA_PATH, args.subj, f"freegrasping_{datetime_now}.log")
  83. logger = logging.LogFile(log_file_path, level=logging.INFO, filemode='w')
  84. # initialize devices and models
  85. if not args.debug:
  86. # load model
  87. model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
  88. input_kwargs = {
  89. 'state_trans_prob': args.state_trans_prob,
  90. 'state_change_threshold': args.state_change_threshold,
  91. 'momentum': args.momentum
  92. }
  93. control_model = model_loader(model_path, **input_kwargs)
  94. # build bci controller
  95. controller = Controller(0., control_model, reref_method=config_info['reref'])
  96. # connect neo
  97. receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']),
  98. samplerate=config_info['sample_rate'],
  99. host=config_info['host'],
  100. port=config_info['port'],
  101. buffer_len=config_info['buffer_length'])
  102. # connect to trigger box
  103. trigger = TriggerNeuracle(port=args.trigger_port)
  104. # connect to mechanical hand
  105. hand_device = FingerController(mode='step', init_params={'port': args.hand_port})
  106. serial_port = args.arduino_port
  107. baud_rate = 9600
  108. # Connect to Arduino
  109. ser = connect_to_arduino(serial_port, baud_rate)
  110. # setup window
  111. win = visual.Window(
  112. size=[1920, 1080], fullscr=True, screen=0,
  113. winType='pyglet', allowStencil=False,
  114. monitor='testMonitor', color=[1,1,1], colorSpace='rgb',
  115. backgroundImage='', backgroundFit='none',
  116. blendMode='avg', useFBO=True,
  117. units='height'
  118. )
  119. fps = win.getActualFrameRate()
  120. update_interval = 0.1 # second
  121. # --- Initialize components for Routine "decision" ---
  122. feedback_bar = visual.Progress(
  123. win, name='feedback_bar',
  124. progress=0,
  125. pos=(0, -0.25), size=(0.5, 0.1), anchor='bottom-left', units='height',
  126. barColor='black', backColor=None, borderColor='black', colorSpace='rgb',
  127. lineWidth=4.0, opacity=1.0, ori=270.0,
  128. depth=0
  129. )
  130. text_message = visual.TextStim(win=win, name='text',
  131. text='您将在接下来的任务中自主控制气动手,\n进度条提示您当前时刻的抓握力度。\n希望气动手握紧请用力尝试握手,\n希望气动手松开请尝试放松。\n按空格键继续',
  132. font='Open Sans',
  133. pos=(0, 0), height=0.05, wrapWidth=None, ori=0.0,
  134. color='black', colorSpace='rgb', opacity=None,
  135. languageStyle='LTR',
  136. depth=0.0)
  137. text_message.draw()
  138. win.flip()
  139. event.waitKeys()
  140. while True:
  141. keys = event.getKeys()
  142. if 'escape' in keys:
  143. break
  144. if not args.debug:
  145. data_from_buffer = receiver.get_trial_data(clear=False)
  146. decision = controller.decision(data_from_buffer, None)
  147. force = controller.real_feedback_model.probability[1]
  148. else:
  149. decision = random.randint(-1, 1)
  150. force = random.random()
  151. current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
  152. logging.exp(f'Prob: {force} - current time: {current_time}')
  153. feedback_bar.progress = force
  154. feedback_bar.draw()
  155. win.flip()
  156. if decision != -1:
  157. # logging decision change
  158. current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
  159. logging.exp(f'Decision: {decision:d} - current time: {current_time}')
  160. if not args.debug:
  161. trigger.send_trigger(int(decision))
  162. hand_device.move(action=fingermodel_ids_inverse[decision])
  163. send_toggle_command(ser)
  164. if decision == 0: # only when state changes to rest, give a freeze time
  165. core.wait(3)
  166. else:
  167. core.wait(0.1)
  168. # exit exp
  169. if not args.debug:
  170. receiver.close()
  171. logging.flush()