# neo.py — loading and preprocessing of NEO / Kraken ECoG recording sessions.
import numpy as np
import os
import json
import mne
import glob
import pyedflib
from scipy import signal
from .utils import upsample_events
from settings.config import settings

# Module-level shortcuts into the project-wide settings object.
# FINGERMODEL_IDS maps a finger-model name (including 'rest') to its
# integer event id; CONFIG_INFO carries the remaining configuration.
FINGERMODEL_IDS = settings.FINGERMODEL_IDS
CONFIG_INFO = settings.CONFIG_INFO
  12. def raw_preprocessing(data_root, session_paths:dict,
  13. do_rereference=True,
  14. upsampled_epoch_length=1.,
  15. ori_epoch_length=5,
  16. unify_label=True,
  17. mov_trial_ind=[2, 3],
  18. rest_trial_ind=[4]):
  19. """
  20. Params:
  21. subj_root:
  22. session_paths: dict of lists
  23. do_rereference (bool): do common average rereference or not
  24. upsampled_epoch_length:
  25. ori_epoch_length (int or 'varied'): original epoch length in second
  26. unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
  27. mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
  28. rest_trial_ind: only used when unify_label == True,
  29. """
  30. raws_loaded = load_sessions(data_root, session_paths)
  31. # process event
  32. raws = []
  33. for (finger_model, raw) in raws_loaded:
  34. fs = raw.info['sfreq']
  35. events, _ = mne.events_from_annotations(raw)
  36. if not unify_label:
  37. mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
  38. rest_trial_ind = [0]
  39. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  40. trial_duration = ori_epoch_length
  41. elif ori_epoch_length == 'varied':
  42. trial_duration = None
  43. else:
  44. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  45. events = reconstruct_events(events, fs, finger_model,
  46. mov_trial_ind=mov_trial_ind,
  47. rest_trial_ind=rest_trial_ind,
  48. trial_duration=trial_duration,
  49. use_original_label=not unify_label)
  50. events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
  51. annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
  52. raw.set_annotations(annotations)
  53. raws.append(raw)
  54. raws = mne.concatenate_raws(raws)
  55. raws.load_data()
  56. if do_rereference:
  57. # common average
  58. raws.set_eeg_reference('average')
  59. # high pass
  60. raws = raws.filter(1, None)
  61. # filter 50Hz
  62. raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  63. return raws
  64. def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
  65. """重构出事件序列中的单独运动事件
  66. Args:
  67. fs: int
  68. finger_model:
  69. """
  70. # Trial duration are fixed to be ? seconds.
  71. # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
  72. # ignore initialRest
  73. # extract trials
  74. deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
  75. deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
  76. trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
  77. events_new = events[trials_ind_deduplicated]
  78. if trial_duration is None:
  79. events_new[:-1, 1] = np.diff(events_new[:, 0])
  80. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  81. else:
  82. events_new[:, 1] = int(trial_duration * fs)
  83. events_final = events_new.copy()
  84. if (not use_original_label) and (finger_model is not None):
  85. # process mov
  86. ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
  87. events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
  88. # process rest
  89. ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
  90. events_final[ind_rest, 2] = 0
  91. return events_final
  92. def load_sessions(data_root, session_names: dict):
  93. # return raws for different finger models on an interleaved manner
  94. raw_cnt = sum(len(session_names[k]) for k in session_names)
  95. raws = []
  96. i = 0
  97. while i < raw_cnt:
  98. for finger_model in session_names.keys():
  99. try:
  100. s = session_names[finger_model].pop(0)
  101. i += 1
  102. except IndexError:
  103. continue
  104. if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
  105. # neo format
  106. raw = load_neuracle(os.path.join(data_root, s))
  107. else:
  108. # kraken format
  109. data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
  110. raw = mne.io.read_raw_bdf(data_file)
  111. raws.append((finger_model, raw))
  112. return raws
  113. def load_neuracle(data_dir, data_type='ecog'):
  114. """
  115. neuracle file loader
  116. :param
  117. data_dir: root data dir for the experiment
  118. sfreq:
  119. data_type:
  120. :return:
  121. raw: mne.io.RawArray
  122. """
  123. f = {
  124. 'data': os.path.join(data_dir, 'data.bdf'),
  125. 'evt': os.path.join(data_dir, 'evt.bdf'),
  126. 'info': os.path.join(data_dir, 'recordInformation.json')
  127. }
  128. # read json
  129. with open(f['info'], 'r') as json_file:
  130. record_info = json.load(json_file)
  131. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  132. sfreq = record_info['SampleRate']
  133. # read data
  134. f_data = pyedflib.EdfReader(f['data'])
  135. ch_names = f_data.getSignalLabels()
  136. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6 # to Volt
  137. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  138. raw = mne.io.RawArray(data, info)
  139. # read event
  140. try:
  141. f_evt = pyedflib.EdfReader(f['evt'])
  142. onset, duration, content = f_evt.readAnnotations()
  143. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  144. onset = (onset * sfreq).astype(np.int64)
  145. duration = (np.array(duration) * sfreq).astype(np.int64)
  146. event_mapping = {c: i for i, c in enumerate(np.unique(content))}
  147. event_ids = [event_mapping[i] for i in content]
  148. events = np.stack((onset, duration, event_ids), axis=1)
  149. annotations = mne.annotations_from_events(events, sfreq)
  150. raw.set_annotations(annotations)
  151. except OSError:
  152. pass
  153. return raw