|
@@ -0,0 +1,119 @@
|
|
|
+import os
|
|
|
+import argparse
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import yaml
|
|
|
+import numpy as np
|
|
|
+from bci_core.frequencybandselection_helpers import freq_selection_class_dis
|
|
|
+from dataloaders import neo
|
|
|
+from sklearn.model_selection import ShuffleSplit
|
|
|
+from settings.config import settings
|
|
|
+
|
|
|
+
|
|
|
+config_info = settings.CONFIG_INFO
|
|
|
+
|
|
|
+
|
|
|
+def parse_args():
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description='Model validation'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--subj',
|
|
|
+ dest='subj',
|
|
|
+ help='Subject name',
|
|
|
+ default=None,
|
|
|
+ type=str
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--band-min',
|
|
|
+ dest='b_min',
|
|
|
+ help='Band lower range',
|
|
|
+ default=5.,
|
|
|
+ type=float
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--band-max',
|
|
|
+ dest='b_max',
|
|
|
+ help='Band upper range',
|
|
|
+ default=45.,
|
|
|
+ type=float
|
|
|
+ )
|
|
|
+ return parser.parse_args()
|
|
|
+
|
|
|
+
|
|
|
+args = parse_args()
|
|
|
+
|
|
|
+subj_name = args.subj
|
|
|
+
|
|
|
+data_dir = f'./data/{subj_name}/'
|
|
|
+
|
|
|
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
|
|
|
+ info = yaml.safe_load(f)
|
|
|
+sessions = info['sessions']
|
|
|
+event_id = {'rest': 0}
|
|
|
+for f in sessions.keys():
|
|
|
+ event_id[f] = neo.FINGERMODEL_IDS[f]
|
|
|
+
|
|
|
+trial_duration = config_info['buffer_length']
|
|
|
+# preprocess raw
|
|
|
+raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
|
|
|
+
|
|
|
+###############################################################################
|
|
|
+# Pipeline with a frequency band selection based on the class distinctiveness
|
|
|
+# ----------------------------------------------------------------------------
|
|
|
+#
|
|
|
+# Step1: Select frequency band maximizing class distinctiveness on
|
|
|
+# training set.
|
|
|
+#
|
|
|
+# Define parameters for frequency band selection
|
|
|
+freq_band = [args.b_min, args.b_max]
|
|
|
+sub_band_width = 4.
|
|
|
+sub_band_step = 2.
|
|
|
+alpha = 0.4
|
|
|
+
|
|
|
+# cross validation
|
|
|
+cv = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
|
|
|
+
|
|
|
+# Select frequency band using training set
|
|
|
+best_freq, all_class_dis = \
|
|
|
+ freq_selection_class_dis(raw, freq_band, sub_band_width,
|
|
|
+ sub_band_step, alpha,
|
|
|
+ tmin=0., tmax=0.5,
|
|
|
+ cv=cv,
|
|
|
+ return_class_dis=True, verbose=False)
|
|
|
+
|
|
|
+print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
|
|
|
+
|
|
|
+###############################################################################
|
|
|
+# Plot selected frequency bands
|
|
|
+# ----------------------------------
|
|
|
+#
|
|
|
+# Plot the class distinctiveness values for each sub_band,
|
|
|
+# along with the highlight of the finally selected frequency band.
|
|
|
+
|
|
|
+subband_fmin = np.arange(freq_band[0],
|
|
|
+ freq_band[1] - sub_band_width + 1.,
|
|
|
+ sub_band_step)
|
|
|
+subband_fmax = np.arange(freq_band[0] + sub_band_width,
|
|
|
+ freq_band[1] + 1., sub_band_step)
|
|
|
+n_subband = len(subband_fmin)
|
|
|
+
|
|
|
+subband_fmean = (subband_fmin + subband_fmax) / 2
|
|
|
+
|
|
|
+x = subband_fmean
|
|
|
+
|
|
|
+fig, ax = plt.subplots(1, 1, figsize=(8, 5))
|
|
|
+ax.plot(x, all_class_dis[0], marker='o')
|
|
|
+
|
|
|
+ax.set_ylabel('Class distinctiveness')
|
|
|
+ax.set_xlabel('Filter bank [Hz]')
|
|
|
+ax.set_title('Class distinctiveness value of each subband')
|
|
|
+ax.tick_params(labelsize='large')
|
|
|
+
|
|
|
+fig.tight_layout()
|
|
|
+fig.savefig('freq_selection.pdf')
|
|
|
+
|
|
|
+
|
|
|
+print(f'Optimal frequency band for this subject is '
|
|
|
+ f'{best_freq[0][0]} - {best_freq[0][1]} Hz')
|
|
|
+
|
|
|
+plt.show()
|