
Merge branch 'band_selection' of dk/kraken into master

dk · 1 year ago
commit 7c126c9c94

+ 13 - 1   .vscode/launch.json

@@ -34,6 +34,18 @@
             "--model-path", "./static/models/XW01/baseline_rest+flex_11-21-2023-09-59-49.pkl"]
         },
         {
+            "name": "Band selection",
+            "type": "python",
+            "request": "launch",
+            "program": "band_selection.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--band-min", "5", 
+            "--band-max", "150"]
+        },
+        {
             "name": "Train model",
             "type": "python",
             "request": "launch",
@@ -54,7 +66,7 @@
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "-scth", "0.75",
-            "--model-filename", "riemann_rest+flex_11-23-2023-10-56-58.pkl"]
+            "--model-filename", "riemann_rest+flex_11-27-2023-11-14-08.pkl"]
         },
         {
             "name": "Python: 当前文件",

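The new "Band selection" entry simply runs band_selection.py from backend/ with the arguments shown. A rough stand-alone equivalent, assuming it is launched from the repository root (a sketch, not part of the commit):

    import subprocess

    # Mirror the "Band selection" launch configuration: run the script from
    # backend/ with the same command-line arguments (values as in the config).
    subprocess.run(
        ["python", "band_selection.py",
         "--subj", "XW01", "--band-min", "5", "--band-max", "150"],
        cwd="backend", check=True,
    )
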
+ 119 - 0   backend/band_selection.py

@@ -0,0 +1,119 @@
+import os
+import argparse
+import matplotlib.pyplot as plt
+import yaml
+import numpy as np
+from bci_core.frequencybandselection_helpers import freq_selection_class_dis
+from dataloaders import neo
+from sklearn.model_selection import ShuffleSplit
+from settings.config import settings
+
+
+config_info = settings.CONFIG_INFO
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Frequency band selection'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--band-min',
+        dest='b_min',
+        help='Band lower range',
+        default=5.,
+        type=float
+    )
+    parser.add_argument(
+        '--band-max',
+        dest='b_max',
+        help='Band upper range',
+        default=45.,
+        type=float
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+
+subj_name = args.subj
+
+data_dir = f'./data/{subj_name}/'
+
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+    info = yaml.safe_load(f)
+sessions = info['sessions']
+event_id = {'rest': 0}
+for f in sessions.keys():
+    event_id[f] = neo.FINGERMODEL_IDS[f]
+
+trial_duration = config_info['buffer_length']
+# preprocess raw
+raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+
+###############################################################################
+# Pipeline with a frequency band selection based on the class distinctiveness
+# ----------------------------------------------------------------------------
+#
+# Step1: Select frequency band maximizing class distinctiveness on
+# training set.
+#
+# Define parameters for frequency band selection
+freq_band = [args.b_min, args.b_max]
+sub_band_width = 4.
+sub_band_step = 2.
+alpha = 0.4
+
+# cross validation
+cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
+
+# Select frequency band using training set
+best_freq, all_class_dis = \
+    freq_selection_class_dis(raw, freq_band, sub_band_width,
+                             sub_band_step, alpha,
+                             tmin=0., tmax=0.5,
+                             cv=cv,
+                             return_class_dis=True, verbose=False)
+
+print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
+
+###############################################################################
+# Plot selected frequency bands
+# ----------------------------------
+#
+# Plot the class distinctiveness values for each sub_band,
+# along with the highlight of the finally selected frequency band.
+
+subband_fmin = np.arange(freq_band[0],
+                              freq_band[1] - sub_band_width + 1.,
+                              sub_band_step)
+subband_fmax = np.arange(freq_band[0] + sub_band_width,
+                              freq_band[1] + 1., sub_band_step)
+n_subband = len(subband_fmin)
+
+subband_fmean = (subband_fmin + subband_fmax) / 2
+
+x = subband_fmean
+
+fig, ax = plt.subplots(1, 1, figsize=(8, 5))
+ax.plot(x, all_class_dis[0], marker='o')
+
+ax.set_ylabel('Class distinctiveness')
+ax.set_xlabel('Filter bank [Hz]')
+ax.set_title('Class distinctiveness value of each subband')
+ax.tick_params(labelsize='large')
+
+fig.tight_layout()
+fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))
+
+
+print(f'Optimal frequency band for this subject is '
+      f'{best_freq[0][0]} - {best_freq[0][1]} Hz')
+
+plt.show()

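For orientation, a worked example of the sub-band grid reconstructed for plotting above: with the launch arguments (--band-min 5, --band-max 150) and the hard-coded sub_band_width = 4 and sub_band_step = 2, the grid covers 71 overlapping sub-bands 5-9 Hz, 7-11 Hz, ..., 145-149 Hz. A stand-alone sketch of the same arithmetic:

    import numpy as np

    # Same arithmetic as subband_fmin / subband_fmax in band_selection.py, with the
    # launch-config values --band-min 5 and --band-max 150.
    freq_band, sub_band_width, sub_band_step = [5., 150.], 4., 2.
    subband_fmin = np.arange(freq_band[0], freq_band[1] - sub_band_width + 1., sub_band_step)
    subband_fmax = np.arange(freq_band[0] + sub_band_width, freq_band[1] + 1., sub_band_step)
    print(len(subband_fmin))                    # 71 sub-bands
    print(subband_fmin[0], subband_fmax[0])     # 5.0 9.0
    print(subband_fmin[-1], subband_fmax[-1])   # 145.0 149.0
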
+ 2 - 1   backend/bci_core/frequencybandselection_helpers.py

@@ -9,7 +9,7 @@ This file contains helper functions for the frequency band selection example
 import numpy as np
 from mne import Epochs, events_from_annotations
 
-from pyriemann.estimation import Covariances
+from pyriemann.estimation import Covariances, Shrinkage
 from pyriemann.classification import class_distinctiveness
 
 
@@ -163,6 +163,7 @@ def _get_filtered_cov(raw, picks, event_id, fmin, fmax, tmin, tmax, verbose):
     epochs_data = epochs.get_data(units="uV")
 
     cov_data = Covariances().transform(epochs_data)
+    cov_data = Shrinkage().transform(cov_data)
 
     return cov_data, labels
 

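The change above regularizes each per-epoch sample covariance with a shrinkage step before class_distinctiveness is computed, which helps keep the matrices well-conditioned when epochs are short relative to the channel count. A minimal stand-alone sketch of the two-step estimation on synthetic data (not project code):

    import numpy as np
    from pyriemann.estimation import Covariances, Shrinkage

    rng = np.random.default_rng(0)
    epochs_data = rng.standard_normal((10, 8, 500))   # (n_epochs, n_channels, n_times)
    cov_data = Covariances().transform(epochs_data)   # one sample covariance per epoch
    cov_data = Shrinkage().transform(cov_data)        # shrink each matrix toward a scaled identity
    print(cov_data.shape)                             # (10, 8, 8)
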
+ 0 - 2   backend/bci_core/model.py

@@ -3,10 +3,8 @@ import numpy as np
 from sklearn.linear_model import LogisticRegression
 from pyriemann.tangentspace import TangentSpace
 from pyriemann.preprocessing import Whitening
-
 from sklearn.pipeline import make_pipeline
 from sklearn.base import BaseEstimator, TransformerMixin
-
 from mne.decoding import Vectorizer
 
 

+ 5 - 2   backend/dataloaders/neo.py

@@ -15,6 +15,7 @@ CONFIG_INFO = settings.CONFIG_INFO
 
 
 def raw_preprocessing(data_root, session_paths:dict, 
+                      do_rereference=True,
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
                       unify_label=True,
@@ -24,6 +25,7 @@ def raw_preprocessing(data_root, session_paths:dict,
     Params:
         subj_root: 
         session_paths: dict of lists
+        do_rereference (bool): whether to apply a common average reference
         upsampled_epoch_length: 
         ori_epoch_length (int or 'varied'): original epoch length in second
         unify_label (bool): True to unify event labels (the original data used the old pony format without unified labels); False to keep the original labels
@@ -61,8 +63,9 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
-    # common average
-    raws.set_eeg_reference('average')
+    if do_rereference:
+        # common average
+        raws.set_eeg_reference('average')
     # high pass
     raws = raws.filter(1, None)
     # filter 50Hz

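The new do_rereference flag only guards the common-average step; the high-pass and 50 Hz notch filtering that follow are unchanged, and in this commit band_selection.py passes do_rereference=False while training.py keeps the default. What the guarded call does when it runs, as a stand-alone MNE sketch on synthetic data (not project code):

    import numpy as np
    import mne

    info = mne.create_info([f"ch{i}" for i in range(8)], sfreq=1000., ch_types="eeg")
    raw = mne.io.RawArray(np.random.randn(8, 10000) * 1e-5, info)
    raw.set_eeg_reference("average")   # common average reference (skipped when do_rereference=False)
    raw.filter(1, None)                # the high-pass that always follows in raw_preprocessing
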
+ 10 - 11   backend/training.py

@@ -41,26 +41,24 @@ def parse_args():
     return parser.parse_args()
 
 
-def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
+def train_model(raw, event_id, trial_duration=1., model_type='baseline', **model_kwargs):
     """
     """
     events, _ = mne.events_from_annotations(raw, event_id=event_id)
     if model_type.lower() == 'baseline':
-        model = _train_baseline_model(raw, events, duration=trial_duration)
+        model = _train_baseline_model(raw, events, duration=trial_duration, **model_kwargs)
     elif model_type.lower() == 'riemann':
-        # TODO: load subject config
-        model = _train_riemann_model(raw, events, duration=trial_duration)
+        model = _train_riemann_model(raw, events, duration=trial_duration, **model_kwargs)
     else:
         raise NotImplementedError
     return model
 
 
-def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)], hg_bands=[(55, 95), (105, 145)]):
+def _train_riemann_model(raw, events, duration=1., lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
     fs = raw.info['sfreq']
     n_ch = len(raw.ch_names)
-    feat_extractor = feature_extractors.FeatExtractor(fs, lfb_bands, hg_bands)
+    feat_extractor = feature_extractors.FeatExtractor(fs, lf_bands, hg_bands)
     filtered_data = feat_extractor.transform(raw.get_data())
-    # TODO: find proper latency
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
     y = events[:, -1]
     
@@ -68,7 +66,7 @@ def _train_riemann_model(raw, events, duration=1., lfb_bands=[(15, 35), (35, 55)
     X = scaler.fit_transform(X)
 
     # compute covariance
-    lfb_dim = len(lfb_bands) * n_ch
+    lfb_dim = len(lf_bands) * n_ch
     hgs_dim = len(hg_bands) * n_ch
     cov_model = BlockCovariances([lfb_dim, hgs_dim], estimator='lwf')
     X_cov = cov_model.fit_transform(X)
@@ -130,12 +128,13 @@ if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
     model_type = args.model_type
-    # TODO: load subject config
-    # include frequency band, model_type, upsampled_trial_duration
 
     data_dir = f'./data/{subj_name}/'
     model_dir = './static/models/'
 
+    with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
+        model_config = yaml.safe_load(f)[model_type]
+
     with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -148,7 +147,7 @@ if __name__ == '__main__':
     raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
-    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
+    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)
     
     # save
     model_saver(model, model_dir, model_type, subj_name, event_id)

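training.py now loads a per-subject ./data/<subj>/model_config.yml, selects the section matching --model-type, and forwards it to the training function as keyword arguments; the commit reads this file but does not add one. A hypothetical config matching _train_riemann_model's keyword arguments (lf_bands, hg_bands), and how it would be consumed:

    import textwrap
    import yaml

    # Hypothetical contents of ./data/<subj>/model_config.yml; keys under each
    # model type must match the kwargs of the corresponding training function.
    example_yaml = textwrap.dedent("""
        riemann:
          lf_bands: [[15, 35], [35, 50]]
          hg_bands: [[55, 95], [105, 145]]
        baseline: {}
    """)
    model_config = yaml.safe_load(example_yaml)["riemann"]
    print(model_config["lf_bands"])   # [[15, 35], [35, 50]]
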
+ 0 - 1   backend/validation.py

@@ -150,7 +150,6 @@ def _event_to_stim_channel(events, time_length):
 if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
-    # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/'