# neo.py
  1. import numpy as np
  2. import os
  3. import json
  4. import mne
  5. import glob
  6. import pyedflib
  7. from .utils import upsample_events
  8. from settings.config import settings
  9. FINGERMODEL_IDS = settings.FINGERMODEL_IDS
  10. FINGERMODEL_IDS_INVERSE = settings.FINGERMODEL_IDS_INVERSE
  11. CONFIG_INFO = settings.CONFIG_INFO
  12. def raw_loader(data_root, session_paths:dict,
  13. reref_method='monopolar',
  14. use_ori_events=False,
  15. upsampled_epoch_length=1.,
  16. ori_epoch_length=5):
  17. """
  18. Params:
  19. data_root:
  20. session_paths: dict of lists
  21. reref_method (str): rereference method: monopolar, average, or bipolar
  22. upsampled_epoch_length (None or float): None: do not do upsampling
  23. ori_epoch_length (int, dict, or 'varied'): original epoch length in second
  24. """
  25. raws_loaded = load_sessions(data_root, session_paths, reref_method)
  26. # process event
  27. raws = []
  28. event_id = {}
  29. for (finger_model, raw) in raws_loaded:
  30. fs = raw.info['sfreq']
  31. {d: int(d) for d in np.unique(raw.annotations.description)}
  32. events, _ = mne.events_from_annotations(raw, event_id={d: int(d) for d in np.unique(raw.annotations.description)})
  33. event_id = event_id | {FINGERMODEL_IDS_INVERSE[int(d)]: int(d) for d in np.unique(raw.annotations.description)}
  34. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  35. trial_duration = ori_epoch_length
  36. elif ori_epoch_length == 'varied':
  37. trial_duration = None
  38. elif isinstance(ori_epoch_length, dict):
  39. trial_duration = ori_epoch_length
  40. else:
  41. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  42. events = reconstruct_events(events, fs,
  43. use_ori_events=use_ori_events,
  44. trial_duration=trial_duration)
  45. if upsampled_epoch_length is not None:
  46. events = upsample_events(events, int(fs * upsampled_epoch_length))
  47. event_desc = {e: FINGERMODEL_IDS_INVERSE[e] for e in np.unique(events[:, 2])}
  48. annotations = mne.annotations_from_events(events, fs, event_desc)
  49. raw.set_annotations(annotations)
  50. raws.append(raw)
  51. raws = mne.concatenate_raws(raws)
  52. raws.load_data()
  53. return raws, event_id
  54. def reref(raw, method='average'):
  55. if method == 'average':
  56. return raw.set_eeg_reference('average')
  57. elif method == 'bipolar':
  58. anode = CONFIG_INFO['strips'][0] + CONFIG_INFO['strips'][1][1:][::-1]
  59. cathode = CONFIG_INFO['strips'][0][1:] + CONFIG_INFO['strips'][1][::-1]
  60. return mne.set_bipolar_reference(raw, anode, cathode)
  61. elif method == 'monopolar':
  62. return raw
  63. else:
  64. raise ValueError(f'Rereference method unacceptable, got {str(method)}, expect "monopolar" or "average" or "bipolar"')
  65. def preprocessing(raw, reref_method='monopolar', keeptimes=None):
  66. # cut by the first and last annotations
  67. annotation_onset, annotation_offset = raw.annotations.onset[0], raw.annotations.onset[-1]
  68. if keeptimes is None:
  69. keeptimes = 10.
  70. tmin, tmax = max(annotation_onset - keeptimes, raw.times[0]), min(annotation_offset + keeptimes, raw.times[-1])
  71. # rebuilt the raw
  72. # MNE的crop函数会导致annotation错乱,只能重建raw object
  73. new_annotations = mne.Annotations(onset=raw.annotations.onset - tmin,
  74. duration=raw.annotations.duration,
  75. description=raw.annotations.description)
  76. info = raw.info
  77. fs = info['sfreq']
  78. data = raw.get_data()
  79. # crop data
  80. data = data[..., int(tmin * fs):int(tmax * fs)]
  81. raw = mne.io.RawArray(data, info)
  82. raw.set_annotations(new_annotations)
  83. # do signal preprocessing
  84. raw.load_data()
  85. raw = reref(raw, reref_method)
  86. # high pass
  87. raw = raw.filter(1, None)
  88. # filter 50Hz
  89. raw = raw.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  90. return raw
  91. def reconstruct_events(events, fs, trial_duration=5, use_ori_events=False):
  92. """重构出事件序列中的单独运动事件
  93. Args:
  94. events (np.ndarray):
  95. fs (float):
  96. trial_duration (float or None or dict): None means variable epoch length, dict means there are different trial durations for different trials
  97. use_ori_events: skip deduplication
  98. """
  99. # Trial duration are fixed to be ? seconds.
  100. # extract trials
  101. if use_ori_events:
  102. events_new = events.copy()
  103. else:
  104. trials_ind_deduplicated = np.flatnonzero(np.diff(events[:, 2], prepend=0) != 0)
  105. events_new = events[trials_ind_deduplicated]
  106. if trial_duration is None:
  107. events_new[:-1, 1] = np.diff(events_new[:, 0])
  108. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  109. elif isinstance(trial_duration, dict):
  110. for e in trial_duration.keys():
  111. if isinstance(trial_duration[e], list):
  112. onset, offset = trial_duration[e]
  113. else:
  114. onset, offset = 0., trial_duration[e]
  115. duration = offset - onset
  116. events_new[events_new[:, 2] == e, 0] += int(onset * fs)
  117. events_new[events_new[:, 2] == e, 1] = int(duration * fs)
  118. else:
  119. events_new[:, 1] = int(trial_duration * fs)
  120. return events_new
  121. def load_sessions(data_root, session_names: dict, reref_method='monopolar'):
  122. # return raws for different finger models on an interleaved manner
  123. raw_cnt = sum(len(session_names[k]) for k in session_names)
  124. raws = []
  125. i = 0
  126. while i < raw_cnt:
  127. for finger_model in session_names.keys():
  128. try:
  129. s = session_names[finger_model].pop(0)
  130. i += 1
  131. except IndexError:
  132. continue
  133. # load raw
  134. raw = load_neuracle(os.path.join(data_root, s))
  135. # preprocess raw
  136. raw = preprocessing(raw, reref_method)
  137. # append list
  138. raws.append((finger_model, raw))
  139. return raws
  140. def load_neuracle(data_dir, data_type='ecog'):
  141. """
  142. neuracle file loader
  143. :param
  144. data_dir: root data dir for the experiment
  145. sfreq:
  146. data_type:
  147. :return:
  148. raw: mne.io.RawArray
  149. """
  150. f = {
  151. 'data': os.path.join(data_dir, 'data.bdf'),
  152. 'evt': os.path.join(data_dir, 'evt.bdf'),
  153. 'info': os.path.join(data_dir, 'recordInformation.json')
  154. }
  155. # read json
  156. with open(f['info'], 'r') as json_file:
  157. record_info = json.load(json_file)
  158. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  159. sfreq = record_info['SampleRate']
  160. # read data
  161. f_data = pyedflib.EdfReader(f['data'])
  162. ch_names = f_data.getSignalLabels()
  163. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6 # to Volt
  164. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  165. raw = mne.io.RawArray(data, info)
  166. # read event
  167. try:
  168. f_evt = pyedflib.EdfReader(f['evt'])
  169. onset, duration, content = f_evt.readAnnotations()
  170. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  171. onset = (onset * sfreq).astype(np.int64)
  172. try:
  173. content = content.astype(np.int64) # use original event code
  174. except ValueError:
  175. event_mapping = {c: i + 1 for i, c in enumerate(np.unique(content))}
  176. content = [event_mapping[i] for i in content]
  177. duration = (np.array(duration) * sfreq).astype(np.int64)
  178. events = np.stack((onset, duration, content), axis=1)
  179. annotations = mne.annotations_from_events(events, sfreq)
  180. raw.set_annotations(annotations)
  181. except OSError:
  182. pass
  183. return raw