import glob
import json
import os

import mne
import numpy as np
import pyedflib

from .utils import upsample_events

# Unified event codes for the different grasp / finger-movement paradigms.
FINGERMODEL_IDS = {
    'rest': 0,
    'cylinder': 1,
    'ball': 2,
    'flex': 3,
    'double': 4,
    'treble': 5
}


def raw_preprocessing(data_root, session_paths: dict,
                      upsampled_epoch_length=1.,
                      ori_epoch_length=5,
                      rename_event=True):
    """Load, re-annotate and filter the raw recordings of one subject.

    Params:
        data_root: root directory containing the session folders
        session_paths: dict mapping finger model name -> list of session folder names
        upsampled_epoch_length: length (in seconds) of the sub-epochs that each
            original trial is split into
        ori_epoch_length (int, float or 'varied'): original trial length in seconds;
            'varied' infers each trial's duration from the event sequence
        rename_event: if True, relabel events with the unified FINGERMODEL_IDS codes;
            if False, keep the original event labels
    """
    raws_loaded = load_sessions(data_root, session_paths)

    # rebuild the event stream of each session
    raws = []
    for finger_model, raw in raws_loaded:
        fs = raw.info['sfreq']
        events, _ = mne.events_from_annotations(raw)

        mov_trial_ind = [2, 3]
        rest_trial_ind = [4]
        if not rename_event:
            mov_trial_ind = [finger_model]
            rest_trial_ind = [0]

        if isinstance(ori_epoch_length, (int, float)):
            trial_duration = ori_epoch_length
        elif ori_epoch_length == 'varied':
            trial_duration = None
        else:
            raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')

        events = reconstruct_events(events, fs, finger_model,
                                    mov_trial_ind=mov_trial_ind,
                                    rest_trial_ind=rest_trial_ind,
                                    trial_duration=trial_duration,
                                    use_original_label=not rename_event)
        events_upsampled = upsample_events(events, int(fs * upsampled_epoch_length))
        annotations = mne.annotations_from_events(
            events_upsampled, fs,
            {FINGERMODEL_IDS[finger_model]: finger_model,
             FINGERMODEL_IDS['rest']: 'rest'})
        raw.set_annotations(annotations)
        raws.append(raw)

    raws = mne.concatenate_raws(raws)
    raws.load_data()
    # notch out 50 Hz line noise and its harmonics
    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
    return raws


def reconstruct_events(events, fs, finger_model, trial_duration=5,
                       mov_trial_ind=(2, 3), rest_trial_ind=(4,),
                       use_original_label=False):
    """Reconstruct individual movement/rest trial events from the raw event sequence.

    Args:
        events: MNE event array of shape (n_events, 3)
        fs: sampling frequency in Hz
        finger_model: name of the movement paradigm, used to look up its unified code
        trial_duration: fixed trial length in seconds, or None to infer each trial's
            duration from the spacing between consecutive trial onsets
        mov_trial_ind: original event codes marking movement trials
        rest_trial_ind: original event codes marking rest trials
        use_original_label: keep the original event codes instead of FINGERMODEL_IDS
    """
    # Original annotation codes: initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4.
    # initialRest is ignored. Trials are assumed to have a fixed duration of
    # trial_duration seconds unless trial_duration is None (varied-length trials).

    # keep only the first event of each contiguous run of movement / rest codes
    deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
    deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
    trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
    events_new = events[trials_ind_deduplicated]
    if trial_duration is None:
        # varied-length trials: duration is the gap to the next trial onset
        events_new[:-1, 1] = np.diff(events_new[:, 0])
        events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
    else:
        events_new[:, 1] = int(trial_duration * fs)
    events_final = events_new.copy()
    if not use_original_label and finger_model is not None:
        # relabel movement trials with the unified finger-model code
        ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
        events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
        # relabel rest trials
        ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
        events_final[ind_rest, 2] = 0
    return events_final


def load_sessions(data_root, session_names: dict):
    """Load raw recordings for the different finger models in an interleaved manner.

    Note: the lists in session_names are consumed (popped) while loading.
    """
    raw_cnt = sum(len(session_names[k]) for k in session_names)
    raws = []
    i = 0
    while i < raw_cnt:
        for finger_model in session_names.keys():
            try:
                s = session_names[finger_model].pop(0)
                i += 1
            except IndexError:
                continue
            if glob.glob(os.path.join(data_root, s, 'evt.bdf')):
                # neo format: separate data.bdf / evt.bdf plus recordInformation.json
                raw = load_neuracle(os.path.join(data_root, s))
            else:
                # kraken format: a single BDF file
                data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                raw = mne.io.read_raw_bdf(data_file)
            raws.append((finger_model, raw))
    return raws


def load_neuracle(data_dir, data_type='ecog'):
    """Neuracle file loader.

    :param data_dir: directory containing data.bdf, evt.bdf and recordInformation.json
    :param data_type: channel type passed to mne.create_info (e.g. 'ecog')
    :return: raw: mne.io.RawArray with annotations attached when evt.bdf is readable
    """
    f = {
        'data': os.path.join(data_dir, 'data.bdf'),
        'evt': os.path.join(data_dir, 'evt.bdf'),
        'info': os.path.join(data_dir, 'recordInformation.json')
    }

    # read recording metadata
    with open(f['info'], 'r') as json_file:
        record_info = json.load(json_file)
    start_time_point = record_info['DataFileInformations'][0]['BeginTimeStamp']
    sfreq = record_info['SampleRate']

    # read signal data; scale from microvolts to volts (MNE expects SI units)
    f_data = pyedflib.EdfReader(f['data'])
    ch_names = f_data.getSignalLabels()
    data = np.array([f_data.readSignal(i) for i in range(f_data.signals_in_file)]) * 1e-6
    info = mne.create_info(ch_names, sfreq, [data_type] * len(ch_names))
    raw = mne.io.RawArray(data, info)

    # read events and attach them as annotations
    try:
        f_evt = pyedflib.EdfReader(f['evt'])
        onset, duration, content = f_evt.readAnnotations()
        # correct onsets by the recording start time (BeginTimeStamp, apparently in milliseconds)
        onset = np.array(onset) - start_time_point * 1e-3
        onset = (onset * sfreq).astype(np.int64)
        duration = (np.array(duration) * sfreq).astype(np.int64)
        event_mapping = {c: i for i, c in enumerate(np.unique(content))}
        event_ids = [event_mapping[i] for i in content]
        events = np.stack((onset, duration, event_ids), axis=1)

        annotations = mne.annotations_from_events(events, sfreq)
        raw.set_annotations(annotations)
    except OSError:
        pass

    return raw
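# ---------------------------------------------------------------------------
# Usage sketch (an assumption-laden example, not part of the pipeline): shows
# how the functions above might be combined. The data directory, session
# folder names and epoching parameters below are hypothetical placeholders.
# Because of the relative import at the top of this file, run it as a module
# of its package (python -m <package>.<module>), not as a standalone script.
if __name__ == '__main__':
    data_root = '/path/to/subject'          # hypothetical data directory
    session_paths = {
        'cylinder': ['session_01'],         # hypothetical session folder names
        'ball': ['session_02'],
    }
    raw = raw_preprocessing(data_root, session_paths,
                            upsampled_epoch_length=1.,
                            ori_epoch_length=5,
                            rename_event=True)
    # The returned Raw carries annotations named after the finger models
    # ('cylinder', 'ball', ...) plus 'rest', so standard MNE epoching applies;
    # tmin/tmax here are illustrative values, not prescribed by this module.
    events, event_id = mne.events_from_annotations(raw)
    epochs = mne.Epochs(raw, events, event_id,
                        tmin=0., tmax=1., baseline=None, preload=True)
    print(epochs)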