123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import os
- import argparse
- import matplotlib.pyplot as plt
- import yaml
- import numpy as np
- from bci_core.frequencybandselection_helpers import freq_selection_class_dis
- from dataloaders import neo
- from sklearn.model_selection import ShuffleSplit
- from settings.config import settings
# Shared runtime configuration (this script reads its 'reref' and
# 'buffer_length' entries when loading the raw recordings).
config_info = getattr(settings, 'CONFIG_INFO')
def parse_args(argv=None):
    """Parse command-line options for frequency band selection.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (useful for testing). ``None`` reads ``sys.argv[1:]``
            exactly as before.

    Returns:
        argparse.Namespace with ``subj`` (str), ``b_min`` and ``b_max``
        (float, band limits) and ``band_div`` (float, dividing point between
        the low and high frequency bands).

    Exits via ``parser.error`` (SystemExit) when ``--subj`` is missing or the
    band limits cannot form a log-spaced sub-band grid.
    """
    parser = argparse.ArgumentParser(
        description='Model validation'
    )
    parser.add_argument(
        '--subj',
        dest='subj',
        help='Subject name',
        # The subject name is joined onto the data path, so it cannot be
        # omitted: fail with a usage error here instead of a TypeError from
        # os.path.join(..., None) later in the script.
        required=True,
        type=str
    )
    parser.add_argument(
        '--band-min',
        dest='b_min',
        help='Band lower range',
        default=5.,
        type=float
    )
    parser.add_argument(
        '--band-max',
        dest='b_max',
        help='Band upper range',
        default=200.,
        type=float
    )
    parser.add_argument(
        '--band-div',
        help='Dividing point of low frequency bands and high frequency bands',
        default=50.,
        type=float
    )
    args = parser.parse_args(argv)
    # Sub-band centres are generated with np.log10(b_min)..np.log10(b_max),
    # so the limits must be strictly positive and strictly increasing.
    if args.b_min <= 0:
        parser.error('--band-min must be positive')
    if args.b_max <= args.b_min:
        parser.error('--band-max must be greater than --band-min')
    return args
args = parse_args()
subj_name = args.subj
data_dir = os.path.join(settings.DATA_PATH, subj_name)

# Per-subject training metadata: the session list and, optionally, the
# original epoch length (falls back to 5. when the key is absent).
train_info_path = os.path.join(data_dir, 'train_info.yml')
with open(train_info_path, 'r') as fh:
    info = yaml.safe_load(fh)
sessions = info['sessions']
# The configured buffer length is used as the trial duration and handed to
# the loader as the upsampled epoch length.
trial_duration = config_info['buffer_length']
ori_epoch_length = info.get('ori_epoch_length', 5.)

# preprocess raw
raw, event_id = neo.raw_loader(
    data_dir,
    sessions,
    reref_method=config_info['reref'],
    upsampled_epoch_length=trial_duration,
    ori_epoch_length=ori_epoch_length,
)
###############################################################################
# Pipeline with a frequency band selection based on the class distinctiveness
# ----------------------------------------------------------------------------
#
# Step1: Select frequency band maximizing class distinctiveness on
# training set.
#
# Define parameters for frequency band selection
freq_band = [args.b_min, args.b_max]
alpha = 0.4
# Candidate sub-bands: centre frequencies log-spaced over [b_min, b_max];
# each sub-band extends one log-step (f_step) either side of its centre.
n_subbands = 30
subband_fmean = np.logspace(np.log10(args.b_min), np.log10(args.b_max), n_subbands)
f_step = (np.log10(args.b_max) - np.log10(args.b_min)) / n_subbands
subband_fmin = subband_fmean * (10 ** (-f_step))
subband_fmax = subband_fmean * (10 ** f_step)
# cross validation: a single shuffled train/test split with a fixed seed
cv = ShuffleSplit(n_splits=1, test_size=None, random_state=42)
# Select frequency band using training set
best_freq, all_class_dis = \
    freq_selection_class_dis(raw, subband_fmin, subband_fmax, alpha, args.band_div,
                             tmin=0., tmax=0.5,
                             cv=cv,
                             return_class_dis=True, verbose=False)
# NOTE(review): best_freq[0] is assumed to hold (low band, high band), each a
# (fmin, fmax) pair in Hz -- inferred from how it is indexed here; confirm
# against freq_selection_class_dis.
low_band = best_freq[0][0]
high_band = best_freq[0][1]
print(f'Selected frequency band (low) : {low_band[0]} - {low_band[1]} Hz')
print(f'Selected frequency band (high) : {high_band[0]} - {high_band[1]} Hz')
# split high frequency band into roughly high_band_width-wide, log-spaced
# pieces. split_high_freqs holds band *edges*, so at least two are needed to
# describe even one band; without the clamp a high band narrower than ~25 Hz
# rounds to a single edge and produces no band at all.
high_band_width = 50.
n_bands = max(int(np.round((high_band[1] - high_band[0]) / high_band_width + 1)), 2)
split_high_freqs = np.logspace(np.log10(high_band[0]), np.log10(high_band[1]), n_bands)
print('Splited high frequency bands: ', split_high_freqs)
###############################################################################
# Plot selected frequency bands
# ----------------------------------
#
# Draw the class distinctiveness value of every candidate sub-band against its
# centre frequency, shading the selected low and high frequency bands.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 5))
ax.plot(subband_fmean, all_class_dis[0], marker='o')
# Shade the low band first, then the high band (same order as before).
for selected in (best_freq[0][0], best_freq[0][1]):
    ax.axvspan(selected[0], selected[1], color='orange', alpha=0.5)
ax.set(
    ylabel='Class distinctiveness',
    xlabel='Filter bank [Hz]',
    title='Class distinctiveness value of each subband',
)
ax.tick_params(labelsize='large')
fig.tight_layout()

figure_path = os.path.join(data_dir, 'freq_selection.pdf')
fig.savefig(figure_path)
plt.show()
|