|
@@ -4,12 +4,11 @@ import json
|
|
|
import mne
|
|
|
import glob
|
|
|
import pyedflib
|
|
|
-from scipy import signal
|
|
|
from .utils import upsample_events
|
|
|
from settings.config import settings
|
|
|
|
|
|
-
|
|
|
FINGERMODEL_IDS = settings.FINGERMODEL_IDS
|
|
|
+FINGERMODEL_IDS_INVERSE = settings.FINGERMODEL_IDS_INVERSE
|
|
|
|
|
|
CONFIG_INFO = settings.CONFIG_INFO
|
|
|
|
|
@@ -17,31 +16,25 @@ CONFIG_INFO = settings.CONFIG_INFO
|
|
|
def raw_loader(data_root, session_paths:dict,
|
|
|
do_rereference=True,
|
|
|
upsampled_epoch_length=1.,
|
|
|
- ori_epoch_length=5,
|
|
|
- unify_label=True,
|
|
|
- mov_trial_ind=[2, 3],
|
|
|
- rest_trial_ind=[4]):
|
|
|
+ ori_epoch_length=5):
|
|
|
"""
|
|
|
Params:
|
|
|
- subj_root:
|
|
|
+ data_root:
|
|
|
session_paths: dict of lists
|
|
|
do_rereference (bool): do common average rereference or not
|
|
|
upsampled_epoch_length (None or float): None: do not do upsampling
|
|
|
ori_epoch_length (int or 'varied'): original epoch length in second
|
|
|
- unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
|
|
|
- mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
|
|
|
- rest_trial_ind: only used when unify_label == True,
|
|
|
"""
|
|
|
raws_loaded = load_sessions(data_root, session_paths, do_rereference)
|
|
|
# process event
|
|
|
raws = []
|
|
|
+ event_id = {}
|
|
|
for (finger_model, raw) in raws_loaded:
|
|
|
fs = raw.info['sfreq']
|
|
|
- events, _ = mne.events_from_annotations(raw)
|
|
|
+ {d: int(d) for d in np.unique(raw.annotations.description)}
|
|
|
+ events, _ = mne.events_from_annotations(raw, event_id={d: int(d) for d in np.unique(raw.annotations.description)})
|
|
|
|
|
|
- if not unify_label:
|
|
|
- mov_trial_ind = [FINGERMODEL_IDS[finger_model]]
|
|
|
- rest_trial_ind = [0]
|
|
|
+ event_id = event_id | {FINGERMODEL_IDS_INVERSE[int(d)]: int(d) for d in np.unique(raw.annotations.description)}
|
|
|
|
|
|
if isinstance(ori_epoch_length, int) or isinstance(ori_epoch_length, float):
|
|
|
trial_duration = ori_epoch_length
|
|
@@ -49,21 +42,21 @@ def raw_loader(data_root, session_paths:dict,
|
|
|
trial_duration = None
|
|
|
else:
|
|
|
raise ValueError(f'Unsupported epoch_length {ori_epoch_length}')
|
|
|
- events = reconstruct_events(events, fs, finger_model,
|
|
|
- mov_trial_ind=mov_trial_ind,
|
|
|
- rest_trial_ind=rest_trial_ind,
|
|
|
- trial_duration=trial_duration,
|
|
|
- use_original_label=not unify_label)
|
|
|
+ events = reconstruct_events(events, fs,
|
|
|
+ trial_duration=trial_duration)
|
|
|
if upsampled_epoch_length is not None:
|
|
|
events = upsample_events(events, int(fs * upsampled_epoch_length))
|
|
|
- annotations = mne.annotations_from_events(events, fs, {FINGERMODEL_IDS[finger_model]: finger_model, FINGERMODEL_IDS['rest']: 'rest'})
|
|
|
+
|
|
|
+ event_desc = {e: FINGERMODEL_IDS_INVERSE[e] for e in np.unique(events[:, 2])}
|
|
|
+ annotations = mne.annotations_from_events(events, fs, event_desc)
|
|
|
raw.set_annotations(annotations)
|
|
|
raws.append(raw)
|
|
|
|
|
|
raws = mne.concatenate_raws(raws)
|
|
|
+
|
|
|
raws.load_data()
|
|
|
|
|
|
- return raws
|
|
|
+ return raws, event_id
|
|
|
|
|
|
|
|
|
def preprocessing(raw, do_rereference=True):
|
|
@@ -78,35 +71,27 @@ def preprocessing(raw, do_rereference=True):
|
|
|
return raw
|
|
|
|
|
|
|
|
|
-def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
|
|
|
+def reconstruct_events(events, fs, trial_duration=5):
|
|
|
"""重构出事件序列中的单独运动事件
|
|
|
Args:
|
|
|
- fs: int
|
|
|
- finger_model:
|
|
|
+ events (np.ndarray):
|
|
|
+ fs (float):
|
|
|
+ trial_duration (float or None or dict): None means variable epoch length, dict means there are different trial durations for different trials
|
|
|
"""
|
|
|
# Trial duration are fixed to be ? seconds.
|
|
|
- # initialRest: 1, miFailed & miSuccess: 2 & 3, rest: 4
|
|
|
- # ignore initialRest
|
|
|
# extract trials
|
|
|
-
|
|
|
- deduplicated_mov = np.diff(np.isin(events[:, 2], mov_trial_ind), prepend=0) == 1
|
|
|
- deduplicated_rest = np.diff(np.isin(events[:, 2], rest_trial_ind), prepend=0) == 1
|
|
|
- trials_ind_deduplicated = np.flatnonzero(np.logical_or(deduplicated_mov, deduplicated_rest))
|
|
|
+
|
|
|
+ trials_ind_deduplicated = np.flatnonzero(np.diff(events[:, 2], prepend=0) != 0)
|
|
|
events_new = events[trials_ind_deduplicated]
|
|
|
if trial_duration is None:
|
|
|
events_new[:-1, 1] = np.diff(events_new[:, 0])
|
|
|
events_new[-1, 1] = events[-1, 0] - events_new[-1, 0]
|
|
|
+ elif isinstance(trial_duration, dict):
|
|
|
+ for e in trial_duration.keys():
|
|
|
+ events_new[events_new[:, 2] == e] = trial_duration[e]
|
|
|
else:
|
|
|
events_new[:, 1] = int(trial_duration * fs)
|
|
|
- events_final = events_new.copy()
|
|
|
- if (not use_original_label) and (finger_model is not None):
|
|
|
- # process mov
|
|
|
- ind_mov = np.flatnonzero(np.isin(events_new[:, 2], mov_trial_ind))
|
|
|
- events_final[ind_mov, 2] = FINGERMODEL_IDS[finger_model]
|
|
|
- # process rest
|
|
|
- ind_rest = np.flatnonzero(np.isin(events_new[:, 2], rest_trial_ind))
|
|
|
- events_final[ind_rest, 2] = 0
|
|
|
- return events_final
|
|
|
+ return events_new
|
|
|
|
|
|
|
|
|
def load_sessions(data_root, session_names: dict, do_rereference=True):
|
|
@@ -171,11 +156,15 @@ def load_neuracle(data_dir, data_type='ecog'):
|
|
|
onset, duration, content = f_evt.readAnnotations()
|
|
|
onset = np.array(onset) - start_time_point * 1e-3 # correct by start time point
|
|
|
onset = (onset * sfreq).astype(np.int64)
|
|
|
+ try:
|
|
|
+ content = content.astype(np.int64) # use original event code
|
|
|
+ except ValueError:
|
|
|
+ event_mapping = {c: i for i, c in enumerate(np.unique(content))}
|
|
|
+ content = [event_mapping[i] for i in content]
|
|
|
|
|
|
duration = (np.array(duration) * sfreq).astype(np.int64)
|
|
|
- event_mapping = {c: i for i, c in enumerate(np.unique(content))}
|
|
|
- event_ids = [event_mapping[i] for i in content]
|
|
|
- events = np.stack((onset, duration, event_ids), axis=1)
|
|
|
+
|
|
|
+ events = np.stack((onset, duration, content), axis=1)
|
|
|
|
|
|
annotations = mne.annotations_from_events(events, sfreq)
|
|
|
raw.set_annotations(annotations)
|