123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- import os
- import argparse
- import matplotlib.pyplot as plt
- import yaml
- import numpy as np
- from bci_core.frequencybandselection_helpers import freq_selection_class_dis
- from dataloaders import neo
- from sklearn.model_selection import ShuffleSplit
- from settings.config import settings
# Project configuration taken from the settings module; this script later
# reads its 'buffer_length' entry to set the trial duration.
config_info = settings.CONFIG_INFO
def parse_args(argv=None):
    """Parse the command-line options of the band selection script.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) the process
        command line (``sys.argv[1:]``) is parsed, which is the original
        behaviour; passing a list makes the parser usable from tests or
        other scripts.

    Returns
    -------
    argparse.Namespace
        Namespace with ``subj`` (str, or ``None`` when not given),
        ``b_min`` (float, default 5.0) and ``b_max`` (float, default 45.0).

    Raises
    ------
    SystemExit
        Raised by argparse on malformed arguments, and when the requested
        band is empty or inverted (``--band-min`` >= ``--band-max``).
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        default=None,
        type=str
    )
    parser.add_argument(
        '--band-min',
        dest='b_min',
        help='Band lower range',
        default=5.,
        type=float
    )
    parser.add_argument(
        '--band-max',
        dest='b_max',
        help='Band upper range',
        default=45.,
        type=float
    )
    args = parser.parse_args(argv)

    # An empty or inverted band cannot contain any sub-band; reject it here
    # with a usage message instead of failing later in the pipeline.
    if args.b_min >= args.b_max:
        parser.error(
            f'--band-min ({args.b_min}) must be smaller than '
            f'--band-max ({args.b_max})'
        )
    return args
# Resolve which subject to process from the command line.
args = parse_args()
subj_name = args.subj
if not subj_name:
    # --subj defaults to None; without this guard the script would go on to
    # look for './data/None/train_info.yml' and fail with a misleading
    # "file not found" error.
    raise SystemExit('error: a subject name is required (use --subj NAME)')

data_dir = f'./data/{subj_name}/'

# train_info.yml provides the list of sessions to load for this subject.
with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
    info = yaml.safe_load(f)
sessions = info['sessions']

# The trial duration comes from the 'buffer_length' configuration entry.
trial_duration = config_info['buffer_length']

# preprocess raw
raw, event_id = neo.raw_loader(
    data_dir,
    sessions,
    do_rereference=True,
    upsampled_epoch_length=trial_duration,
    ori_epoch_length=5,
)
###############################################################################
# Frequency band selection driven by class distinctiveness
# ----------------------------------------------------------------------------
#
# Step 1: scan the requested range with a bank of sub-bands and keep the
# band that maximizes class distinctiveness on the training data.

# Search range requested on the command line: [lower, upper].
freq_band = [args.b_min, args.b_max]

# Sub-band width and the step between consecutive sub-bands.
sub_band_width, sub_band_step = 5., 5.

# Forwarded as-is to freq_selection_class_dis.
alpha = 0.4

# One shuffled split with a fixed seed, so repeated runs use the same split.
cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)

# Run the selection; return_class_dis=True also returns the class
# distinctiveness values that are plotted further down.
best_freq, all_class_dis = freq_selection_class_dis(
    raw,
    freq_band,
    sub_band_width,
    sub_band_step,
    alpha,
    tmin=0.,
    tmax=0.5,
    cv=cv,
    return_class_dis=True,
    verbose=False,
)

print(f'Selected frequency band : {best_freq[0][0]} - {best_freq[0][1]} Hz')
###############################################################################
# Class distinctiveness per sub-band
# ----------------------------------
#
# Draw the class distinctiveness value obtained for every sub-band, save the
# figure next to the subject's data and report the selected band.

# Lower and upper edges of each sub-band across the search range.
subband_fmin = np.arange(
    freq_band[0],
    freq_band[1] - sub_band_width + 1.,
    sub_band_step,
)
subband_fmax = np.arange(
    freq_band[0] + sub_band_width,
    freq_band[1] + 1.,
    sub_band_step,
)
n_subband = len(subband_fmin)

# Each point is plotted at the centre frequency of its sub-band.
subband_fmean = (subband_fmin + subband_fmax) / 2
x = subband_fmean

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(x, all_class_dis[0], marker='o')
ax.set(
    ylabel='Class distinctiveness',
    xlabel='Filter bank [Hz]',
    title='Class distinctiveness value of each subband',
)
ax.tick_params(labelsize='large')
fig.tight_layout()

# Keep a copy of the figure in the subject's data directory.
figure_path = os.path.join(data_dir, 'freq_selection.pdf')
fig.savefig(figure_path)

print(f'Optimal frequency band for this subject is '
      f'{best_freq[0][0]} - {best_freq[0][1]} Hz')

plt.show()
|