"""Select EEG/ECoG frequency bands by class distinctiveness for one subject.

Loads the subject's sessions listed in ``train_info.yml``, runs the
class-distinctiveness band selection over a log-spaced filter bank and
plots the per-subband scores.
"""
import os
import argparse

import matplotlib.pyplot as plt
import yaml
import numpy as np

from bci_core.frequencybandselection_helpers import freq_selection_class_dis
from dataloaders import neo
from sklearn.model_selection import ShuffleSplit
from settings.config import settings

config_info = settings.CONFIG_INFO


def parse_args():
    """Parse and validate the command-line options.

    Returns
    -------
    argparse.Namespace
        With attributes ``subj`` (subject name), ``b_min`` / ``b_max``
        (filter-bank range) and ``band_div`` (split point between the low
        and high frequency bands).

    Exits through ``parser.error`` (status 2) when ``--subj`` is missing or
    the band limits cannot describe a log-spaced filter bank.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    # The subject name is joined into a filesystem path below, so it cannot
    # be omitted (a None value would crash os.path.join with a TypeError).
    parser.add_argument(
        '--subj', dest='subj',
        help='Subject name', required=True, type=str
    )
    parser.add_argument(
        '--band-min', dest='b_min',
        help='Band lower range', default=5., type=float
    )
    parser.add_argument(
        '--band-max', dest='b_max',
        help='Band upper range', default=200., type=float
    )
    parser.add_argument(
        '--band-div',
        help='Dividing point of low frequency bands and high frequency bands',
        default=50., type=float
    )
    args = parser.parse_args()
    # The band limits are passed through np.log10 / np.logspace later: a
    # non-positive bound yields -inf/NaN and an empty or inverted range
    # yields a degenerate filter bank, so reject them up front.
    if not 0 < args.b_min < args.b_max:
        parser.error(
            f'band limits must satisfy 0 < --band-min < --band-max '
            f'(got {args.b_min} and {args.b_max})'
        )
    return args


args = parse_args()
subj_name = args.subj
data_dir = os.path.join(settings.DATA_PATH, subj_name)

# train_info.yml must provide 'sessions'; 'ori_epoch_length' is optional.
with open(os.path.join(data_dir, 'train_info.yml'), 'r', encoding='utf-8') as f:
    info = yaml.safe_load(f)
sessions = info['sessions']
trial_duration = config_info['buffer_length']
# Falls back to 5. when the YAML omits it (unit presumably seconds — confirm).
ori_epoch_length = info.get('ori_epoch_length', 5.)

# preprocess raw
raw, event_id = neo.raw_loader(data_dir, sessions,
                               reref_method=config_info['reref'],
                               upsampled_epoch_length=trial_duration,
                               ori_epoch_length=ori_epoch_length)

###############################################################################
# Pipeline with frequency band selection based on class distinctiveness
# ----------------------------------------------------------------------------
#
# Step 1: Select the frequency band that maximizes class distinctiveness on
# the training set.
#
# Filter-bank definition: 30 log-spaced centre frequencies spanning
# [b_min, b_max]. Each sub-band extends ``f_step`` (in log10 units) to
# either side of its centre. Because f_step is slightly larger than half the
# spacing between neighbouring centres, adjacent sub-bands overlap.
freq_band = [args.b_min, args.b_max]
alpha = 0.4  # forwarded to freq_selection_class_dis; meaning defined there
log_b_min, log_b_max = np.log10(args.b_min), np.log10(args.b_max)
subband_fmean = np.logspace(log_b_min, log_b_max, 30)
f_step = (log_b_max - log_b_min) / 30
subband_fmin = subband_fmean * (10 ** (-f_step))
subband_fmax = subband_fmean * (10 ** f_step)

# A single shuffled train/test split. With test_size=None and no
# train_size, scikit-learn falls back to its default test fraction.
cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)

# Run the selection on the training portion of the split.
best_freq, all_class_dis = freq_selection_class_dis(
    raw, subband_fmin, subband_fmax, alpha, args.band_div,
    tmin=0., tmax=0.5, cv=cv, return_class_dis=True, verbose=False,
)

# best_freq[0] holds a (low band, high band) pair, each as (fmin, fmax).
# The leading [0] presumably selects the first (only) CV split — confirm
# against freq_selection_class_dis.
low_band = best_freq[0][0]
high_band = best_freq[0][1]
print(f'Selected frequency band (low) : {low_band[0]} - {low_band[1]} Hz')
print(f'Selected frequency band (high) : {high_band[0]} - {high_band[1]} Hz')

# Divide the selected high band into log-spaced edges. The number of edges
# is the rounded value of span / 50 + 1, so there is roughly one segment per
# 50 Hz of bandwidth.
high_span = high_band[1] - high_band[0]
n_bands = np.round(high_span / 50 + 1).astype(int)
split_high_freqs = np.logspace(np.log10(high_band[0]),
                               np.log10(high_band[1]),
                               n_bands)
print('Splited high frequency bands: ', split_high_freqs)

###############################################################################
# Plot selected frequency bands
# ----------------------------------
#
# Plot the class distinctiveness value of each sub-band and shade the
# finally selected low and high frequency bands.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(subband_fmean, all_class_dis[0], marker='o')
for band in (low_band, high_band):
    ax.axvspan(band[0], band[1], color='orange', alpha=0.5)
ax.set_ylabel('Class distinctiveness')
ax.set_xlabel('Filter bank [Hz]')
ax.set_title('Class distinctiveness value of each subband')
ax.tick_params(labelsize='large')
fig.tight_layout()
fig.savefig(os.path.join(data_dir, 'freq_selection.pdf'))
plt.show()