"""Select a subject-specific frequency band by class distinctiveness.

Loads the subject's training sessions listed in
``./data/<subj>/train_info.yml``, runs ``freq_selection_class_dis`` over a
grid of sub-bands inside ``[--band-min, --band-max]`` and plots the class
distinctiveness value of every sub-band with the selected band shaded.
The figure is written to ``./data/<subj>/freq_selection.pdf``.
"""
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import yaml
from sklearn.model_selection import ShuffleSplit

from bci_core.frequencybandselection_helpers import freq_selection_class_dis
from dataloaders import neo
from settings.config import settings

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse and validate the command-line arguments.

    Returns
    -------
    argparse.Namespace
        ``subj`` (str): subject name, used as the data sub-directory.
        ``b_min`` / ``b_max`` (float): lower / upper edge of the frequency
        range searched, in Hz.

    Exits with a usage error (via ``parser.error``) when ``--subj`` is
    missing or when ``--band-min`` is not below ``--band-max``.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--band-min',
        dest='b_min',
        help='Band lower range',
        default=5.,
        type=float
    )
    parser.add_argument(
        '--band-max',
        dest='b_max',
        help='Band upper range',
        default=45.,
        type=float
    )
    args = parser.parse_args()
    # Without a subject the data path would silently become './data/None/'
    # and the script would only fail later on a missing file.
    if args.subj is None:
        parser.error('--subj is required')
    if args.b_min >= args.b_max:
        parser.error('--band-min must be lower than --band-max')
    return args


def main():
    """Run frequency-band selection for one subject and plot the result."""
    args = parse_args()
    subj_name = args.subj
    data_dir = f'./data/{subj_name}/'

    with open(os.path.join(data_dir, 'train_info.yml'), 'r',
              encoding='utf-8') as f:
        info = yaml.safe_load(f)
    sessions = info['sessions']
    trial_duration = config_info['buffer_length']

    # preprocess raw (the returned event-id mapping is not needed here)
    raw, _ = neo.raw_loader(data_dir, sessions, do_rereference=True,
                            upsampled_epoch_length=trial_duration,
                            ori_epoch_length=5)

    ###########################################################################
    # Pipeline with a frequency band selection based on class distinctiveness
    # ------------------------------------------------------------------------
    #
    # Step 1: select the frequency band maximizing class distinctiveness on
    # the training set.
    #
    # Define parameters for frequency band selection
    freq_band = [args.b_min, args.b_max]
    sub_band_width = 5.
    sub_band_step = 5.
    # Passed straight through to freq_selection_class_dis; see that helper
    # for its exact meaning.
    alpha = 0.4

    # Cross-validation: a single random split. test_size=None defers the
    # split size to scikit-learn's default.
    cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)

    # Select frequency band using the training set
    best_freq, all_class_dis = freq_selection_class_dis(
        raw, freq_band, sub_band_width, sub_band_step, alpha,
        tmin=0., tmax=0.5, cv=cv,
        return_class_dis=True, verbose=False)
    print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')

    ###########################################################################
    # Plot selected frequency bands
    # -----------------------------
    #
    # Plot the class distinctiveness value of each sub-band, with the finally
    # selected frequency band highlighted.
    #
    # NOTE(review): this grid assumes freq_selection_class_dis tiles
    # freq_band the same way (width sub_band_width, step sub_band_step);
    # confirm against the helper if either parameter changes.
    subband_fmin = np.arange(freq_band[0],
                             freq_band[1] - sub_band_width + 1.,
                             sub_band_step)
    # Derive upper edges from the lower edges so both arrays always have the
    # same length (two independent float aranges can differ by one element).
    subband_fmax = subband_fmin + sub_band_width
    subband_fmean = (subband_fmin + subband_fmax) / 2

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))
    ax.plot(subband_fmean, all_class_dis[0], marker='o')
    # Shade the selected band on the same Hz axis as the sub-band centres.
    sel_fmin, sel_fmax = best_freq[0][0], best_freq[0][1]
    ax.axvspan(sel_fmin, sel_fmax, color='tab:orange', alpha=0.2,
               label=f'Selected band: {sel_fmin} - {sel_fmax} Hz')
    ax.legend()
    ax.set_ylabel('Class distinctiveness')
    ax.set_xlabel('Filter bank [Hz]')
    ax.set_title('Class distinctiveness value of each subband')
    ax.tick_params(labelsize='large')
    fig.tight_layout()
    fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))

    print(f'Optimal frequency band for this subject is '
          f'{best_freq[0][0]} - {best_freq[0][1]} Hz')

    plt.show()


if __name__ == '__main__':
    main()