Browse Source

Feat: band selection

dk 1 year ago
parent
commit
5c4863341d
3 changed files with 136 additions and 2 deletions
  1. 12 0
      .vscode/launch.json
  2. 119 0
      backend/band_selection.py
  3. 5 2
      backend/dataloaders/neo.py

+ 12 - 0
.vscode/launch.json

@@ -34,6 +34,18 @@
             "--model-path", "./static/models/XW01/baseline_rest+flex_11-21-2023-09-59-49.pkl"]
         },
         {
+            "name": "Band selection",
+            "type": "python",
+            "request": "launch",
+            "program": "band_selection.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--band-min", "5", 
+            "--band-max", "55"]
+        },
+        {
             "name": "Train model",
             "type": "python",
             "request": "launch",

+ 119 - 0
backend/band_selection.py

@@ -0,0 +1,119 @@
+import os
+import argparse
+import matplotlib.pyplot as plt
+import yaml
+import numpy as np
+from bci_core.frequencybandselection_helpers import freq_selection_class_dis
+from dataloaders import neo
+from sklearn.model_selection import ShuffleSplit
+from settings.config import settings
+
+
+config_info = settings.CONFIG_INFO
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--band-min',
+        dest='b_min',
+        help='Band lower range',
+        default=5.,
+        type=float
+    )
+    parser.add_argument(
+        '--band-max',
+        dest='b_max',
+        help='Band upper range',
+        default=45.,
+        type=float
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+
+subj_name = args.subj
+
+data_dir = f'./data/{subj_name}/'
+
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+    info = yaml.safe_load(f)
+sessions = info['sessions']
+event_id = {'rest': 0}
+for f in sessions.keys():
+    event_id[f] = neo.FINGERMODEL_IDS[f]
+
+trial_duration = config_info['buffer_length']
+# preprocess raw
+raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+
+###############################################################################
+# Pipeline with a frequency band selection based on the class distinctiveness
+# ----------------------------------------------------------------------------
+#
+# Step1: Select frequency band maximizing class distinctiveness on
+# training set.
+#
+# Define parameters for frequency band selection
+freq_band = [args.b_min, args.b_max]
+sub_band_width = 4.
+sub_band_step = 2.
+alpha = 0.4
+
+# cross validation
+cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
+
+# Select frequency band using training set
+best_freq, all_class_dis = \
+    freq_selection_class_dis(raw, freq_band, sub_band_width,
+                             sub_band_step, alpha,
+                             tmin=0., tmax=0.5,
+                             cv=cv,
+                             return_class_dis=True, verbose=False)
+
+print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
+
+###############################################################################
+# Plot selected frequency bands
+# ----------------------------------
+#
+# Plot the class distinctiveness values for each sub_band,
+# along with the highlight of the finally selected frequency band.
+
+subband_fmin = np.arange(freq_band[0],
+                              freq_band[1] - sub_band_width + 1.,
+                              sub_band_step)
+subband_fmax = np.arange(freq_band[0] + sub_band_width,
+                              freq_band[1] + 1., sub_band_step)
+n_subband = len(subband_fmin)
+
+subband_fmean = (subband_fmin + subband_fmax) / 2
+
+x = subband_fmean
+
+fig, ax = plt.subplots(1, 1, figsize=(8, 5))
+ax.plot(x, all_class_dis[0], marker='o')
+
+ax.set_ylabel('Class distinctiveness')
+ax.set_xlabel('Filter bank [Hz]')
+ax.set_title('Class distinctiveness value of each subband')
+ax.tick_params(labelsize='large')
+
+fig.tight_layout()
+fig.savefig('freq_selection.pdf')
+
+
+print(f'Optimal frequency band for this subject is '
+      f'{best_freq[0][0]} - {best_freq[0][1]} Hz')
+
+plt.show()

+ 5 - 2
backend/dataloaders/neo.py

@@ -15,6 +15,7 @@ CONFIG_INFO = settings.CONFIG_INFO
 
 
 def raw_preprocessing(data_root, session_paths:dict, 
+                      do_rereference=True,
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
                       unify_label=True,
@@ -24,6 +25,7 @@ def raw_preprocessing(data_root, session_paths:dict,
     Params:
         subj_root: 
         session_paths: dict of lists
+        do_rereference (bool): do common average rereference or not
         upsampled_epoch_length: 
         ori_epoch_length (int or 'varied'): original epoch length in second
         unify_label (True, original data didn't use unified event label, do unification (old pony format); False use original)
@@ -61,8 +63,9 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
-    # common average
-    raws.set_eeg_reference('average')
+    if do_rereference:
+        # common average
+        raws.set_eeg_reference('average')
     # high pass
     raws = raws.filter(1, None)
     # filter 50Hz