123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- import numpy as np
- import os
- import json
- import mne
- import glob
- import pyedflib
- from .utils import upsample_events
# Unified event codes shared by all sessions: 0 is always rest,
# the remaining ids name the grasp / finger-movement paradigms.
FINGERMODEL_IDS = dict(
    rest=0,
    cylinder=1,
    ball=2,
    flex=3,
    double=4,
    treble=5,
)
def raw_preprocessing(data_root, session_paths: dict,
                      upsampled_epoch_length=1.,
                      ori_epoch_length=5,
                      rename_event=True):
    """Load sessions, rebuild trial annotations and return one filtered Raw.

    Params:
        data_root: root directory that contains the session folders
        session_paths: dict mapping finger model name -> list of session dirs
        upsampled_epoch_length: target epoch length in seconds after
            upsampling the trial events
        ori_epoch_length (int, float or 'varied'): original epoch length in
            seconds; 'varied' infers each trial's duration from the events
        rename_event: True -> relabel events with the unified
            FINGERMODEL_IDS codes; False -> keep the original event codes

    Returns:
        mne.io.Raw: concatenated, notch-filtered raw recording

    Raises:
        ValueError: if ori_epoch_length is neither numeric nor 'varied'
    """
    raws_loaded = load_sessions(data_root, session_paths)
    # rebuild the event stream of every session before concatenation
    raws = []
    for (finger_model, raw) in raws_loaded:
        fs = raw.info['sfreq']
        events, _ = mne.events_from_annotations(raw)
        # default (renamed) coding: miFailed/miSuccess -> 2/3, rest -> 4
        mov_trial_ind = [2, 3]
        rest_trial_ind = [4]
        if not rename_event:
            mov_trial_ind = [finger_model]
            rest_trial_ind = [0]

        if isinstance(ori_epoch_length, (int, float)):
            trial_duration = ori_epoch_length
        elif ori_epoch_length == 'varied':
            # None tells reconstruct_events to infer per-trial durations
            trial_duration = None
        else:
            raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
        events = reconstruct_events(events, fs, finger_model,
                                    mov_trial_ind=mov_trial_ind,
                                    rest_trial_ind=rest_trial_ind,
                                    trial_duration=trial_duration,
                                    use_original_label=not rename_event)

        # split each long trial into epochs of upsampled_epoch_length seconds
        events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
        annotations = mne.annotations_from_events(
            events_upsampled, fs,
            {FINGERMODEL_IDS[finger_model]: finger_model,
             FINGERMODEL_IDS['rest']: 'rest'})
        raw.set_annotations(annotations)
        raws.append(raw)
    raws = mne.concatenate_raws(raws)
    raws.load_data()
    # notch out 50 Hz mains interference and its harmonics
    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
    return raws
def reconstruct_events(events, fs, finger_model, trial_duration=5,
                       mov_trial_ind=(2, 3), rest_trial_ind=(4,),
                       use_original_label=False):
    """Reconstruct individual movement/rest trial events from the raw stream.

    Consecutive repeats of the same trial marker are collapsed to their
    first occurrence (the rising edge of the membership mask), then every
    surviving trial is given a duration and, optionally, a unified label.

    Args:
        events: (n, 3) int array of (onset_sample, duration, code) rows
        fs: sampling frequency in Hz
        finger_model: key into FINGERMODEL_IDS used for relabelling
            movement trials (ignored when None or use_original_label=True)
        trial_duration: fixed trial length in seconds; None means varied
            trials whose duration is inferred from the next trial onset
        mov_trial_ind: event codes that mark movement trials
        rest_trial_ind: event codes that mark rest trials
        use_original_label: keep the original event codes instead of
            remapping them to FINGERMODEL_IDS / 0

    Returns:
        (m, 3) int array of deduplicated trial events
    """
    # Original session coding: initialRest: 1, miFailed & miSuccess: 2 & 3,
    # rest: 4. initialRest is ignored by construction of the masks below.
    # NOTE: defaults are tuples (not lists) to avoid shared mutable defaults.
    deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
    deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
    trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
    events_new = events[trials_ind_deduplicated]
    if trial_duration is None:
        # varied-length trials: duration runs until the next trial onset;
        # the final trial extends to the last event of the original stream
        events_new[:-1, 1] = np.diff(events_new[:, 0])
        events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
    else:
        events_new[:, 1] = int(trial_duration * fs)
    events_final = events_new.copy()
    if not use_original_label and finger_model is not None:
        # remap movement codes to the unified finger-model id
        ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
        events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
        # remap rest codes to the unified rest id (0)
        ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
        events_final[ind_rest, 2] = 0
    return events_final
def load_sessions(data_root, session_names: dict):
    """Load raw recordings for each finger model in an interleaved manner.

    Sessions are consumed round-robin across finger models so the returned
    list alternates between paradigms. Fixes the original behaviour of
    destructively pop()-ing the caller's lists: the input dict is no longer
    mutated.

    Args:
        data_root: root directory containing the session folders
        session_names: dict mapping finger model name -> list of session dirs

    Returns:
        list of (finger_model, mne.io.Raw) tuples
    """
    # work on shallow copies so the caller's dict of lists stays intact
    queues = {k: list(v) for k, v in session_names.items()}
    raws = []
    while any(queues.values()):
        for finger_model, queue in queues.items():
            if not queue:
                continue
            s = queue.pop(0)
            if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
                # neuracle (neo) format: events live in a separate evt.bdf
                raw = load_neuracle(os.path.join(data_root, s))
            else:
                # kraken format: a single .bdf holds data and annotations
                data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                raw = mne.io.read_raw_bdf(data_file)
            raws.append((finger_model, raw))
    return raws
def load_neuracle(data_dir, data_type='ecog'):
    """Neuracle file loader.

    Reads data.bdf (signals), evt.bdf (events) and recordInformation.json
    (metadata) from one experiment directory. Fixes a resource leak: the
    pyedflib.EdfReader handles are now always closed.

    :param data_dir: root data dir for the experiment
    :param data_type: MNE channel type assigned to every channel
    :return: mne.io.RawArray, with annotations attached when evt.bdf is
        present and readable
    """
    f = {
        'data': os.path.join(data_dir, 'data.bdf'),
        'evt': os.path.join(data_dir, 'evt.bdf'),
        'info': os.path.join(data_dir, 'recordInformation.json')
    }
    # read recording metadata
    with open(f['info'], 'r') as json_file:
        record_info = json.load(json_file)
    start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
    sfreq = record_info['SampleRate']
    # read signal data; always release the EDF handle, even on failure
    f_data = pyedflib.EdfReader(f['data'])
    try:
        ch_names = f_data.getSignalLabels()
        # scale by 1e-6: BDF stores microvolts, MNE expects volts
        data = np.array([f_data.readSignal(i)
                         for i in range(f_data.signals_in_file)]) * 1e-6
    finally:
        f_data.close()
    info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
    raw = mne.io.RawArray(data, info)
    # read events (best-effort: evt.bdf may be absent or unreadable)
    try:
        f_evt = pyedflib.EdfReader(f['evt'])
        try:
            onset, duration, content = f_evt.readAnnotations()
        finally:
            f_evt.close()
        # onsets are absolute seconds; BeginTimeStamp is in milliseconds,
        # so align annotations to the recording start before scaling
        onset = np.array(onset) - start_time_point * 1e-3
        onset = (onset * sfreq).astype(np.int64)
        duration = (np.array(duration) * sfreq).astype(np.int64)
        event_mapping = {c: i for i, c in enumerate(np.unique(content))}
        event_ids = [event_mapping[i] for i in content]
        events = np.stack((onset, duration, event_ids), axis=1)

        annotations = mne.annotations_from_events(events, sfreq)
        raw.set_annotations(annotations)
    except OSError:
        # deliberate best-effort: no usable event file means the raw is
        # returned without annotations rather than failing the whole load
        pass
    return raw
|