# neo.py — Neuracle/Kraken BDF loading and preprocessing
import glob
import json
import os

import mne
import numpy as np
import pyedflib
from scipy import signal

from .utils import upsample_events
from settings.config import settings

FINGERMODEL_IDS = settings.FINGERMODEL_IDS
CONFIG_INFO = settings.CONFIG_INFO
  12. def raw_preprocessing(data_root, session_paths:dict,
  13. upsampled_epoch_length=1.,
  14. ori_epoch_length=5,
  15. unify_label=True,
  16. mov_trial_ind=[2, 3],
  17. rest_trial_ind=[4]):
  18. """
  19. Params:
  20. subj_root:
  21. session_paths: dict of lists
  22. upsampled_epoch_length:
  23. ori_epoch_length (int or 'varied'): original epoch length in second
  24. unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
  25. mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
  26. rest_trial_ind: only used when unify_label == True,
  27. """
  28. raws_loaded = load_sessions(data_root, session_paths)
  29. # process event
  30. raws = []
  31. for (finger_model, raw) in raws_loaded:
  32. fs = raw.info['sfreq']
  33. events, _ = mne.events_from_annotations(raw)
  34. if not unify_label:
  35. mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
  36. rest_trial_ind = [0]
  37. if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
  38. trial_duration = ori_epoch_length
  39. elif ori_epoch_length == 'varied':
  40. trial_duration = None
  41. else:
  42. raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
  43. events = reconstruct_events(events, fs, finger_model,
  44. mov_trial_ind=mov_trial_ind,
  45. rest_trial_ind=rest_trial_ind,
  46. trial_duration=trial_duration,
  47. use_original_label=not unify_label)
  48. events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
  49. annotations = mne.annotations_from_events(events_upsampled, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
  50. raw.set_annotations(annotations)
  51. raws.append(raw)
  52. raws = mne.concatenate_raws(raws)
  53. raws.load_data()
  54. # common average
  55. raws.set_eeg_reference('average')
  56. # high pass
  57. raws = raws.filter(1, None)
  58. # filter 50Hz
  59. raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
  60. return raws
  61. def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
  62. """重构出事件序列中的单独运动事件
  63. Args:
  64. fs: int
  65. finger_model:
  66. """
  67. # Trial duration are fixed to be ? seconds.
  68. # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
  69. # ignore initialRest
  70. # extract trials
  71. deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
  72. deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
  73. trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
  74. events_new = events[trials_ind_deduplicated]
  75. if trial_duration is None:
  76. events_new[:-1, 1] = np.diff(events_new[:, 0])
  77. events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
  78. else:
  79. events_new[:, 1] = int(trial_duration * fs)
  80. events_final = events_new.copy()
  81. if (not use_original_label) and (finger_model is not None):
  82. # process mov
  83. ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
  84. events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
  85. # process rest
  86. ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
  87. events_final[ind_rest, 2] = 0
  88. return events_final
  89. def load_sessions(data_root, session_names: dict):
  90. # return raws for different finger models on an interleaved manner
  91. raw_cnt = sum(len(session_names[k]) for k in session_names)
  92. raws = []
  93. i = 0
  94. while i < raw_cnt:
  95. for finger_model in session_names.keys():
  96. try:
  97. s = session_names[finger_model].pop(0)
  98. i += 1
  99. except IndexError:
  100. continue
  101. if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
  102. # neo format
  103. raw = load_neuracle(os.path.join(data_root, s))
  104. else:
  105. # kraken format
  106. data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
  107. raw = mne.io.read_raw_bdf(data_file)
  108. raws.append((finger_model, raw))
  109. return raws
  110. def load_neuracle(data_dir, data_type='ecog'):
  111. """
  112. neuracle file loader
  113. :param
  114. data_dir: root data dir for the experiment
  115. sfreq:
  116. data_type:
  117. :return:
  118. raw: mne.io.RawArray
  119. """
  120. f = {
  121. 'data': os.path.join(data_dir, 'data.bdf'),
  122. 'evt': os.path.join(data_dir, 'evt.bdf'),
  123. 'info': os.path.join(data_dir, 'recordInformation.json')
  124. }
  125. # read json
  126. with open(f['info'], 'r') as json_file:
  127. record_info = json.load(json_file)
  128. start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
  129. sfreq = record_info['SampleRate']
  130. # read data
  131. f_data = pyedflib.EdfReader(f['data'])
  132. ch_names = f_data.getSignalLabels()
  133. data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6
  134. info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
  135. raw = mne.io.RawArray(data, info)
  136. # read event
  137. try:
  138. f_evt = pyedflib.EdfReader(f['evt'])
  139. onset, duration, content = f_evt.readAnnotations()
  140. onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
  141. onset = (onset * sfreq).astype(np.int64)
  142. duration = (np.array(duration) * sfreq).astype(np.int64)
  143. event_mapping = {c: i for i, c in enumerate(np.unique(content))}
  144. event_ids = [event_mapping[i] for i in content]
  145. events = np.stack((onset, duration, event_ids), axis=1)
  146. annotations = mne.annotations_from_events(events, sfreq)
  147. raw.set_annotations(annotations)
  148. except OSError:
  149. pass
  150. return raw