neo.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import numpy as np
  2. import os
  3. import json
  4. import mne
  5. import glob
  6. import pyedflib
  7. from .utils import upsample_events
  8. from settings.config import settings
  9. FINGERMODEL_IDS = settings.FINGERMODEL_IDS
  10. FINGERMODEL_IDS_INVERSE = settings.FINGERMODEL_IDS_INVERSE
  11. CONFIG_INFO = settings.CONFIG_INFO
  12. def raw_loader(data_root, session_paths:dict,
  13. reref_method='monopolar',
  14. upsampled_epoch_length=1.,
  15. ori_epoch_length=5):
  16. """
  17. Params:
  18. data_root:
  19. session_paths: dict of lists
  20. reref_method (str): rereference method: monopolar, average, or bipolar
  21. upsampled_epoch_length (None or float): None: do not do upsampling
  22. ori_epoch_length (int or 'varied'): original epoch length in second
  23. """
  24. raws_loaded = load_sessions(data_root, session_paths, reref_method)
  25. # process event
  26. raws = []
  27. event_id = {}
  28. for (finger_model, raw) in raws_loaded:
  29. fs = raw.info['sfreq']
  30. {d: int(d) for d in np.unique(raw.annotations.description)}
  31. events, _ = mne.events_from_annotations(raw, event_id={d: int(d) for d in np.unique(raw.annotations.description)})
  32. event_id = event_id | {FINGERMODEL_IDS_INVERSE[int(d)]: int(d) for d in np.unique(raw.annotations.description)}
  33. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  34. trial_duration = ori_epoch_length
  35. elif ori_epoch_length == 'varied':
  36. trial_duration = None
  37. elif isinstance(ori_epoch_length, dict):
  38. trial_duration = ori_epoch_length
  39. else:
  40. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  41. events = reconstruct_events(events, fs,
  42. trial_duration=trial_duration)
  43. if upsampled_epoch_length is not None:
  44. events = upsample_events(events, int(fs * upsampled_epoch_length))
  45. event_desc = {e: FINGERMODEL_IDS_INVERSE[e] for e in np.unique(events[:, 2])}
  46. annotations = mne.annotations_from_events(events, fs, event_desc)
  47. raw.set_annotations(annotations)
  48. raws.append(raw)
  49. raws = mne.concatenate_raws(raws)
  50. raws.load_data()
  51. return raws, event_id
  52. def reref(raw, method='average'):
  53. if method == 'average':
  54. return raw.set_eeg_reference('average')
  55. elif method == 'bipolar':
  56. anode = CONFIG_INFO['strips'][0] + CONFIG_INFO['strips'][1][1:][::-1]
  57. cathode = CONFIG_INFO['strips'][0][1:] + CONFIG_INFO['strips'][1][::-1]
  58. return mne.set_bipolar_reference(raw, anode, cathode)
  59. elif method == 'monopolar':
  60. return raw
  61. else:
  62. raise ValueError(f'Rereference method unacceptable, got {str(method)}, expect "monopolar" or "average" or "bipolar"')
  63. def preprocessing(raw, reref_method='monopolar'):
  64. raw.load_data()
  65. raw = reref(raw, reref_method)
  66. # high pass
  67. raw = raw.filter(1, None)
  68. # filter 50Hz
  69. raw = raw.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  70. return raw
  71. def reconstruct_events(events, fs, trial_duration=5):
  72. """重构出事件序列中的单独运动事件
  73. Args:
  74. events (np.ndarray):
  75. fs (float):
  76. trial_duration (float or None or dict): None means variable epoch length, dict means there are different trial durations for different trials
  77. """
  78. # Trial duration are fixed to be ? seconds.
  79. # extract trials
  80. trials_ind_deduplicated = np.flatnonzero(np.diff(events[:, 2], prepend=0) != 0)
  81. events_new = events[trials_ind_deduplicated]
  82. if trial_duration is None:
  83. events_new[:-1, 1] = np.diff(events_new[:, 0])
  84. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  85. elif isinstance(trial_duration, dict):
  86. for e in trial_duration.keys():
  87. events_new[events_new[:, 2] == e, 1] = int(trial_duration[e] * fs)
  88. else:
  89. events_new[:, 1] = int(trial_duration * fs)
  90. return events_new
  91. def load_sessions(data_root, session_names: dict, reref_method='monopolar'):
  92. # return raws for different finger models on an interleaved manner
  93. raw_cnt = sum(len(session_names[k]) for k in session_names)
  94. raws = []
  95. i = 0
  96. while i < raw_cnt:
  97. for finger_model in session_names.keys():
  98. try:
  99. s = session_names[finger_model].pop(0)
  100. i += 1
  101. except IndexError:
  102. continue
  103. if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
  104. # neo format
  105. raw = load_neuracle(os.path.join(data_root, s))
  106. else:
  107. # kraken format
  108. data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
  109. raw = mne.io.read_raw_bdf(data_file)
  110. # preprocess raw
  111. raw = preprocessing(raw, reref_method)
  112. # append list
  113. raws.append((finger_model, raw))
  114. return raws
  115. def load_neuracle(data_dir, data_type='ecog'):
  116. """
  117. neuracle file loader
  118. :param
  119. data_dir: root data dir for the experiment
  120. sfreq:
  121. data_type:
  122. :return:
  123. raw: mne.io.RawArray
  124. """
  125. f = {
  126. 'data': os.path.join(data_dir, 'data.bdf'),
  127. 'evt': os.path.join(data_dir, 'evt.bdf'),
  128. 'info': os.path.join(data_dir, 'recordInformation.json')
  129. }
  130. # read json
  131. with open(f['info'], 'r') as json_file:
  132. record_info = json.load(json_file)
  133. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  134. sfreq = record_info['SampleRate']
  135. # read data
  136. f_data = pyedflib.EdfReader(f['data'])
  137. ch_names = f_data.getSignalLabels()
  138. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6 # to Volt
  139. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  140. raw = mne.io.RawArray(data, info)
  141. # read event
  142. try:
  143. f_evt = pyedflib.EdfReader(f['evt'])
  144. onset, duration, content = f_evt.readAnnotations()
  145. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  146. onset = (onset * sfreq).astype(np.int64)
  147. try:
  148. content = content.astype(np.int64) # use original event code
  149. except ValueError:
  150. event_mapping = {c: i for i, c in enumerate(np.unique(content))}
  151. content = [event_mapping[i] for i in content]
  152. duration = (np.array(duration) * sfreq).astype(np.int64)
  153. events = np.stack((onset, duration, content), axis=1)
  154. annotations = mne.annotations_from_events(events, sfreq)
  155. raw.set_annotations(annotations)
  156. except OSError:
  157. pass
  158. return raw