neo.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import numpy as np
  2. import os
  3. import json
  4. import mne
  5. import glob
  6. import pyedflib
  7. from scipy import signal
  8. from .utils import upsample_events
  9. from settings.config import settings
# Lookup tables taken from the project settings object.
# In this module FINGERMODEL_IDS is indexed by finger-model name (and by the
# key 'rest') to obtain the integer event code written into event arrays.
FINGERMODEL_IDS = settings.FINGERMODEL_IDS
# CONFIG_INFO is not referenced by the functions in the part of the file
# reviewed here; it is kept as a module-level alias.
CONFIG_INFO = settings.CONFIG_INFO
  12. def raw_preprocessing(data_root, session_paths:dict,
  13. do_rereference=True,
  14. upsampled_epoch_length=1.,
  15. ori_epoch_length=5,
  16. unify_label=True,
  17. mov_trial_ind=[2, 3],
  18. rest_trial_ind=[4]):
  19. """
  20. Params:
  21. subj_root:
  22. session_paths: dict of lists
  23. do_rereference (bool): do common average rereference or not
  24. upsampled_epoch_length (None or float): None: do not do upsampling
  25. ori_epoch_length (int or 'varied'): original epoch length in second
  26. unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
  27. mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
  28. rest_trial_ind: only used when unify_label == True,
  29. """
  30. raws_loaded = load_sessions(data_root, session_paths)
  31. # process event
  32. raws = []
  33. for (finger_model, raw) in raws_loaded:
  34. fs = raw.info['sfreq']
  35. events, _ = mne.events_from_annotations(raw)
  36. if not unify_label:
  37. mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
  38. rest_trial_ind = [0]
  39. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  40. trial_duration = ori_epoch_length
  41. elif ori_epoch_length == 'varied':
  42. trial_duration = None
  43. else:
  44. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  45. events = reconstruct_events(events, fs, finger_model,
  46. mov_trial_ind=mov_trial_ind,
  47. rest_trial_ind=rest_trial_ind,
  48. trial_duration=trial_duration,
  49. use_original_label=not unify_label)
  50. if upsampled_epoch_length is not None:
  51. events = upsample_events(events, int(fs * upsampled_epoch_length))
  52. annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
  53. raw.set_annotations(annotations)
  54. raws.append(raw)
  55. raws = mne.concatenate_raws(raws)
  56. raws.load_data()
  57. if do_rereference:
  58. # common average
  59. raws.set_eeg_reference('average')
  60. # high pass
  61. raws = raws.filter(1, None)
  62. # filter 50Hz
  63. raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  64. return raws
  65. def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
  66. """重构出事件序列中的单独运动事件
  67. Args:
  68. fs: int
  69. finger_model:
  70. """
  71. # Trial duration are fixed to be ? seconds.
  72. # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
  73. # ignore initialRest
  74. # extract trials
  75. deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
  76. deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
  77. trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
  78. events_new = events[trials_ind_deduplicated]
  79. if trial_duration is None:
  80. events_new[:-1, 1] = np.diff(events_new[:, 0])
  81. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  82. else:
  83. events_new[:, 1] = int(trial_duration * fs)
  84. events_final = events_new.copy()
  85. if (not use_original_label) and (finger_model is not None):
  86. # process mov
  87. ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
  88. events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
  89. # process rest
  90. ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
  91. events_final[ind_rest, 2] = 0
  92. return events_final
  93. def load_sessions(data_root, session_names: dict):
  94. # return raws for different finger models on an interleaved manner
  95. raw_cnt = sum(len(session_names[k]) for k in session_names)
  96. raws = []
  97. i = 0
  98. while i < raw_cnt:
  99. for finger_model in session_names.keys():
  100. try:
  101. s = session_names[finger_model].pop(0)
  102. i += 1
  103. except IndexError:
  104. continue
  105. if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
  106. # neo format
  107. raw = load_neuracle(os.path.join(data_root, s))
  108. else:
  109. # kraken format
  110. data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
  111. raw = mne.io.read_raw_bdf(data_file)
  112. raws.append((finger_model, raw))
  113. return raws
  114. def load_neuracle(data_dir, data_type='ecog'):
  115. """
  116. neuracle file loader
  117. :param
  118. data_dir: root data dir for the experiment
  119. sfreq:
  120. data_type:
  121. :return:
  122. raw: mne.io.RawArray
  123. """
  124. f = {
  125. 'data': os.path.join(data_dir, 'data.bdf'),
  126. 'evt': os.path.join(data_dir, 'evt.bdf'),
  127. 'info': os.path.join(data_dir, 'recordInformation.json')
  128. }
  129. # read json
  130. with open(f['info'], 'r') as json_file:
  131. record_info = json.load(json_file)
  132. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  133. sfreq = record_info['SampleRate']
  134. # read data
  135. f_data = pyedflib.EdfReader(f['data'])
  136. ch_names = f_data.getSignalLabels()
  137. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6 # to Volt
  138. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  139. raw = mne.io.RawArray(data, info)
  140. # read event
  141. try:
  142. f_evt = pyedflib.EdfReader(f['evt'])
  143. onset, duration, content = f_evt.readAnnotations()
  144. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  145. onset = (onset * sfreq).astype(np.int64)
  146. duration = (np.array(duration) * sfreq).astype(np.int64)
  147. event_mapping = {c: i for i, c in enumerate(np.unique(content))}
  148. event_ids = [event_mapping[i] for i in content]
  149. events = np.stack((onset, duration, event_ids), axis=1)
  150. annotations = mne.annotations_from_events(events, sfreq)
  151. raw.set_annotations(annotations)
  152. except OSError:
  153. pass
  154. return raw