neo.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import numpy as np
  2. import os
  3. import json
  4. import mne
  5. import glob
  6. import pyedflib
  7. from .utils import upsample_events
  8. FINGERMODEL_IDS = {
  9. 'rest': 0,
  10. 'cylinder': 1,
  11. 'ball': 2,
  12. 'flex': 3,
  13. 'double': 4,
  14. 'treble': 5
  15. }
  16. def raw_preprocessing(data_root, session_paths:dict,
  17. upsampled_epoch_length=1.,
  18. ori_epoch_length=5,
  19. rename_event=True):
  20. """
  21. Params:
  22. subj_root:
  23. session_paths: dict of lists
  24. upsampled_epoch_length:
  25. ori_epoch_length (int or 'varied'): original epoch length in second
  26. rename_event (True, use unified event label, False use original)
  27. """
  28. raws_loaded = load_sessions(data_root, session_paths)
  29. # process event
  30. raws = []
  31. for (finger_model, raw) in raws_loaded:
  32. fs = raw.info['sfreq']
  33. events, _ = mne.events_from_annotations(raw)
  34. mov_trial_ind = [2, 3]
  35. rest_trial_ind = [4]
  36. if not rename_event:
  37. mov_trial_ind = [finger_model]
  38. rest_trial_ind = [0]
  39. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  40. trial_duration = ori_epoch_length
  41. elif ori_epoch_length == 'varied':
  42. trial_duration = None
  43. else:
  44. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  45. events = reconstruct_events(events, fs, finger_model,
  46. mov_trial_ind=mov_trial_ind,
  47. rest_trial_ind=rest_trial_ind,
  48. trial_duration=trial_duration,
  49. use_original_label=not rename_event)
  50. events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
  51. annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
  52. raw.set_annotations(annotations)
  53. raws.append(raw)
  54. raws = mne.concatenate_raws(raws)
  55. raws.load_data()
  56. # filter 50Hz
  57. raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  58. return raws
  59. def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
  60. """重构出事件序列中的单独运动事件
  61. Args:
  62. fs: int
  63. finger_model:
  64. """
  65. # Trial duration are fixed to be ? seconds.
  66. # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
  67. # ignore initialRest
  68. # extract trials
  69. deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
  70. deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
  71. trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
  72. events_new = events[trials_ind_deduplicated]
  73. if trial_duration is None:
  74. events_new[:-1, 1] = np.diff(events_new[:, 0])
  75. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  76. else:
  77. events_new[:, 1] = int(trial_duration * fs)
  78. events_final = events_new.copy()
  79. if not use_original_label and finger_model is not None:
  80. # process mov
  81. ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
  82. events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
  83. # process rest
  84. ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
  85. events_final[ind_rest, 2] = 0
  86. return events_final
  87. def load_sessions(data_root, session_names: dict):
  88. # return raws for different finger models on an interleaved manner
  89. raw_cnt = sum(len(session_names[k]) for k in session_names)
  90. raws = []
  91. i = 0
  92. while i < raw_cnt:
  93. for finger_model in session_names.keys():
  94. try:
  95. s = session_names[finger_model].pop(0)
  96. i += 1
  97. except IndexError:
  98. continue
  99. if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
  100. # neo format
  101. raw = load_neuracle(os.path.join(data_root, s))
  102. else:
  103. # kraken format
  104. data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
  105. raw = mne.io.read_raw_bdf(data_file)
  106. raws.append((finger_model, raw))
  107. return raws
  108. def load_neuracle(data_dir, data_type='ecog'):
  109. """
  110. neuracle file loader
  111. :param
  112. data_dir: root data dir for the experiment
  113. sfreq:
  114. data_type:
  115. :return:
  116. raw: mne.io.RawArray
  117. """
  118. f = {
  119. 'data': os.path.join(data_dir, 'data.bdf'),
  120. 'evt': os.path.join(data_dir, 'evt.bdf'),
  121. 'info': os.path.join(data_dir, 'recordInformation.json')
  122. }
  123. # read json
  124. with open(f['info'], 'r') as json_file:
  125. record_info = json.load(json_file)
  126. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  127. sfreq = record_info['SampleRate']
  128. # read data
  129. f_data = pyedflib.EdfReader(f['data'])
  130. ch_names = f_data.getSignalLabels()
  131. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6
  132. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  133. raw = mne.io.RawArray(data, info)
  134. # read event
  135. try:
  136. f_evt = pyedflib.EdfReader(f['evt'])
  137. onset, duration, content = f_evt.readAnnotations()
  138. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  139. onset = (onset * sfreq).astype(np.int64)
  140. duration = (np.array(duration) * sfreq).astype(np.int64)
  141. event_mapping = {c: i for i, c in enumerate(np.unique(content))}
  142. event_ids = [event_mapping[i] for i in content]
  143. events = np.stack((onset, duration, event_ids), axis=1)
  144. annotations = mne.annotations_from_events(events, sfreq)
  145. raw.set_annotations(annotations)
  146. except OSError:
  147. pass
  148. return raw