|
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
|
|
|
import mne
|
|
|
import yaml
|
|
|
import os
|
|
|
+import argparse
|
|
|
import logging
|
|
|
from scipy import stats
|
|
|
from dataloaders import neo
|
|
@@ -20,6 +21,35 @@ logging.basicConfig(level=logging.DEBUG)
|
|
|
config_info = settings.CONFIG_INFO
|
|
|
|
|
|
|
|
|
+def parse_args():
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description='Model validation'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--subj',
|
|
|
+ dest='subj',
|
|
|
+ help='Subject name',
|
|
|
+ default=None,
|
|
|
+ type=str
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--state-change-threshold',
|
|
|
+ '-scth',
|
|
|
+ dest='state_change_threshold',
|
|
|
+ help='Threshold for HMM state change',
|
|
|
+ default=0.75,
|
|
|
+ type=float
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--model-filename',
|
|
|
+ dest='model_filename',
|
|
|
+ help='Model filename',
|
|
|
+ default=None,
|
|
|
+ type=str
|
|
|
+ )
|
|
|
+ return parser.parse_args()
|
|
|
+
|
|
|
+
|
|
|
class DataGenerator:
|
|
|
def __init__(self, fs, X, epoch_step=1.):
|
|
|
self.fs = int(fs)
|
|
@@ -89,7 +119,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length
|
|
|
|
|
|
corr, _ = stats.pearsonr(stim_pred, stim_true)
|
|
|
|
|
|
- fig_pred, ax = plt.subplots(1, 1, sharex=True, sharey=False)
|
|
|
+ fig_pred, ax = plt.subplots(3, 1, sharex=True, sharey=False)
|
|
|
ax[0].set_title('pred')
|
|
|
ax[0].plot(raw_val.times, stim_pred)
|
|
|
ax[1].set_title('true')
|
|
@@ -118,12 +148,13 @@ def _event_to_stim_channel(events, time_length):
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- # TODO: argparse
|
|
|
- subj_name = 'XW01'
|
|
|
+ args = parse_args()
|
|
|
+ subj_name = args.subj
|
|
|
# TODO: load subject config
|
|
|
|
|
|
data_dir = f'./data/{subj_name}/'
|
|
|
- model_path = f'./static/models/{subj_name}/riemann_rest+flex_11-21-2023-21-23-15.pkl'
|
|
|
+
|
|
|
+ model_path = f'./static/models/{subj_name}/{args.model_filename}'
|
|
|
with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
|
|
|
info = yaml.safe_load(f)
|
|
|
sessions = info['sessions']
|
|
@@ -138,7 +169,7 @@ if __name__ == '__main__':
|
|
|
metrics, fig_erds, fig_pred = validation(raw,
|
|
|
event_id,
|
|
|
model=model_path,
|
|
|
- state_change_threshold=0.75,
|
|
|
+ state_change_threshold=args.state_change_threshold,
|
|
|
step_length=config_info['buffer_length'])
|
|
|
fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
|
|
|
fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))
|