Pārlūkot izejas kodu

增加hmm训练脚本,修改neo loader方法名称

dk 1 gadu atpakaļ
vecāks
revīzija
c64596cbe5

+ 11 - 0
.vscode/launch.json

@@ -83,6 +83,17 @@
             "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         {
+            "name": "Train hmm",
+            "type": "python",
+            "request": "launch",
+            "program": "train_hmm.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
+        },
+        {
             "name": "Python: 当前文件",
             "type": "python",
             "request": "launch",

+ 1 - 1
backend/band_selection.py

@@ -55,7 +55,7 @@ for f in sessions.keys():
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+raw = neo.raw_loader(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 30 - 31
backend/bci_core/online.py

@@ -3,10 +3,9 @@ import numpy as np
 import random
 import logging
 from scipy import signal
-import mne
-from .feature_extractors import filterbank_extractor
 from .utils import parse_model_type
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -21,23 +20,9 @@ class Controller:
     """
     def __init__(self,
                  virtual_feedback_rate=1., 
-                 model_path=None,
-                 state_trans_prob=0.8,
-                 state_change_threshold=0.6):
-        if (model_path is None) or (model_path == 'None'):
-            self.real_feedback_model = None
-        else:
-            self.model_type, _ = parse_model_type(model_path)
-            if self.model_type == 'baseline':
-                self.real_feedback_model = BaselineHMM(model_path, 
-                state_trans_prob=state_trans_prob,
-                state_change_threshold=state_change_threshold)
-            elif self.model_type == 'riemann':
-                self.real_feedback_model = RiemannHMM(model_path, 
-                state_trans_prob=state_trans_prob,
-                state_change_threshold=state_change_threshold)
-            else:
-                raise NotImplementedError
+                 real_feedback_model=None):
+        
+        self.real_feedback_model = real_feedback_model
         self.virtual_feedback_rate = virtual_feedback_rate
 
     def step_decision(self, data, true_label=None):
@@ -113,25 +98,29 @@ class Controller:
 
 
 class HMMModel:
-    def __init__(self, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7):
+    def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
         self.n_classes = n_classes
-        self._probability = np.ones(n_classes) / n_classes
-        self._last_state = 0
+        self._probability = np.zeros(n_classes)
+        self.reset_state()
 
         self.state_change_threshold = state_change_threshold
 
-        # TODO: train with daily use data
-        # build state transition matrix
-        self.state_trans_matrix = np.zeros((n_classes, n_classes))
-        # fill diagonal
-        np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
-        # fill 0 -> each state, 
-        self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
-        self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        if transmat is None:
+            # build state transition matrix
+            self.state_trans_matrix = np.zeros((n_classes, n_classes))
+            # fill diagonal
+            np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
+            # fill 0 -> each state, 
+            self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
+            self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        else:
+            if isinstance(transmat, str):
+                transmat = np.loadtxt(transmat)
+            self.state_trans_matrix = transmat
 
     def reset_state(self):
+        self._probability[0] = 1.
         self._last_state = 0
-        self._probability = np.ones(self.n_classes) / self.n_classes
     
     def set_current_state(self, current_state):
         self._last_state = current_state
@@ -227,3 +216,13 @@ class RiemannHMM(HMMModel):
         # predict proba
         p = self.model.predict_proba(data).squeeze()
         return p
+
+
+def model_loader(model_path, **kwargs):
+    model_type, _ = parse_model_type(model_path)
+    if model_type == 'baseline':
+        return BaselineHMM(model_path, **kwargs)
+    elif model_type == 'riemann':
+        return RiemannHMM(model_path, **kwargs)
+    else:
+        raise ValueError(f'Unexpected model type: {model_type}, expect "baseline" or "riemann"')

+ 16 - 8
backend/dataloaders/neo.py

@@ -14,7 +14,7 @@ FINGERMODEL_IDS = settings.FINGERMODEL_IDS
 CONFIG_INFO = settings.CONFIG_INFO
 
 
-def raw_preprocessing(data_root, session_paths:dict, 
+def raw_loader(data_root, session_paths:dict, 
                       do_rereference=True,
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
@@ -32,7 +32,7 @@ def raw_preprocessing(data_root, session_paths:dict,
         mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
         rest_trial_ind: only used when unify_label == True, 
     """
-    raws_loaded = load_sessions(data_root, session_paths)
+    raws_loaded = load_sessions(data_root, session_paths, do_rereference)
     # process event
     raws = []
     for (finger_model, raw) in raws_loaded:
@@ -63,15 +63,19 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
+    return raws
+
+
+def preprocessing(raw, do_rereference=True):
+    raw.load_data()
     if do_rereference:
         # common average
-        raws.set_eeg_reference('average')
+        raw.set_eeg_reference('average')
     # high pass
-    raws = raws.filter(1, None)
+    raw = raw.filter(1, None)
     # filter 50Hz
-    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
-
-    return raws
+    raw = raw.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
+    return raw
 
 
 def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
@@ -105,7 +109,7 @@ def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind
     return events_final
 
 
-def load_sessions(data_root, session_names: dict):
+def load_sessions(data_root, session_names: dict, do_rereference=True):
     # return raws for different finger models on an interleaved manner
     raw_cnt = sum(len(session_names[k]) for k in session_names)
     raws = []
@@ -124,6 +128,10 @@ def load_sessions(data_root, session_names: dict):
                 # kraken format
                 data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                 raw = mne.io.read_raw_bdf(data_file)
+            # preprocess raw
+            raw = preprocessing(raw, do_rereference)
+
+            # append list
             raws.append((finger_model, raw))
     return raws  
 

+ 13 - 11
backend/online_sim.py

@@ -78,11 +78,11 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
+def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
     val_data = raw.get_data()
     fs = raw.info['sfreq']
 
-    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
+    data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
 
     decision_with_hmm = []
     decision_without_hmm = []
@@ -127,7 +127,8 @@ def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
 def simulation(raw_val, event_id, model, 
                state_trans_prob=0.8,
                state_change_threshold=0.8, 
-               step_length=1., 
+               epoch_length=1., 
+               step_length=0.1,
                event_trial_length=5.):
     """模型验证接口,使用指定数据进行验证,绘制ersd map
     Args:
@@ -135,7 +136,8 @@ def simulation(raw_val, event_id, model,
         event_id (dict)
         model: validate existing model, 
         state_change_threshold (float): default 0.8
-        step_length (float): batch data step length, default 1. (s)
+        epoch_length (float): batch data length, default 1 (s)
+        step_length (float): data step length, default 0.1 (s)
         event_trial_length (float): 
 
     Returns:
@@ -152,17 +154,16 @@ def simulation(raw_val, event_id, model,
                                         mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
                                         use_original_label=True)
     
-    
-    controller = online.Controller(0, model, 
-                                   state_trans_prob=state_trans_prob,
-                                   state_change_threshold=state_change_threshold)
-    model_hmm = controller.real_feedback_model
+    model_hmm = online.model_loader(model, 
+                                    state_trans_prob=state_trans_prob,
+                                    state_change_threshold=state_change_threshold)
 
     # run with and without hmm
     fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
                                                           events_val, 
                                                           model_hmm, 
-                                                          step_length, 
+                                                          epoch_length, 
+                                                          step_length,
                                                           event_trial_length=event_trial_length)
 
     return metric_hmm, metric_naive, fig_pred
@@ -203,7 +204,7 @@ if __name__ == '__main__':
     
     # preprocess raw
     trial_time = 5.
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2], 
@@ -216,6 +217,7 @@ if __name__ == '__main__':
                                              model=model_path, 
                                              state_trans_prob=args.state_trans_prob,
                                              state_change_threshold=args.state_change_threshold,
+                                             epoch_length=config_info['buffer_length'],
                                              step_length=config_info['buffer_length'],
                                              event_trial_length=trial_time)
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   

+ 1 - 1
backend/tests/test_neoloader.py

@@ -10,7 +10,7 @@ class TestDataloader(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
         self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))

+ 6 - 4
backend/tests/test_online.py

@@ -33,7 +33,8 @@ class TestOnline(unittest.TestCase):
         return super().tearDownClass()
     
     def test_step_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
@@ -61,7 +62,8 @@ class TestOnline(unittest.TestCase):
         self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
 
     def test_real_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)
@@ -74,7 +76,7 @@ class TestHMM(unittest.TestCase):
         # binary
         probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
         true_state = [-1, -1, 1, -1, -1, -1, 0]
-        model = online.HMMModel(2, state_trans_prob=0.9, state_change_threshold=0.7)
+        model = online.HMMModel(transmat=None, n_classes=2, state_trans_prob=0.9, state_change_threshold=0.7)
         states = []
         for p in probs:
             cur_state = model.update_state(p)
@@ -84,7 +86,7 @@ class TestHMM(unittest.TestCase):
         # triple
         probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
         true_state = [-1, 1, -1, -1, 0, 2]
-        model = online.HMMModel(3, state_trans_prob=0.9)
+        model = online.HMMModel(transmat=None, n_classes=3, state_trans_prob=0.9)
         states = []
         for p in probs:
             cur_state = model.update_state(p)

+ 1 - 1
backend/tests/test_training.py

@@ -19,7 +19,7 @@ class TestTraining(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
         cls.raw = raw
     

+ 2 - 1
backend/tests/test_validation.py

@@ -51,9 +51,10 @@ class TestOnlineSim(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_sim(self):
-        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
+        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7, epoch_length=1., step_length=0.1, state_trans_prob=0.7)
         fig_pred.savefig('./tests/data/pred.pdf')   
 
+        print(metric_hmm)
         self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
         self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
     

+ 161 - 0
backend/train_hmm.py

@@ -0,0 +1,161 @@
+'''
+Use trained classifier as emission model, train HMM transfer matrix on free grasping tasks
+'''
+import os
+import argparse
+
+from hmmlearn import hmm
+import numpy as np
+import yaml
+import joblib
+from scipy import signal
+import matplotlib.pyplot as plt
+
+from dataloaders import neo
+import bci_core.utils as bci_utils
+from settings.config import settings
+
+
+config_info = settings.CONFIG_INFO
+
+class HMMClassifier(hmm.BaseHMM):
+    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
+    def __init__(self, emission_model, **kwargs):
+        n_components = len(emission_model.classes_)
+        super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
+
+        self.emission_model = emission_model
+    
+    def _check_and_set_n_features(self, X):
+        if X.ndim == 2:  # 
+            n_features = X.shape[1]
+        elif X.ndim == 3:
+            n_features = X.shape[1] * X.shape[2]
+        else:
+            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
+        if hasattr(self, "n_features"):
+            if self.n_features != n_features:
+                raise ValueError(
+                    f"Unexpected number of dimensions, got {n_features} but "
+                    f"expected {self.n_features}")
+        else:
+            self.n_features = n_features
+    
+    def _get_n_fit_scalars_per_param(self):
+        nc = self.n_components
+        return {
+            "s": nc,
+            "t": nc ** 2}
+        
+    def _compute_likelihood(self, X):
+        p = self.emission_model.predict_proba(X)
+        return p
+
+
+def extract_baseline_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, _ = model
+    filter_bank_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
+    # decimate
+    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    # to 10Hz
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    return filter_bank_epoch
+
+
+def extract_riemann_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, scaler, cov_model, _ = model
+    filtered_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
+    X = scaler.transform(X)
+    X_cov = cov_model.transform(X)
+    return X_cov
+
+
+def _split_continuous(time_range, step, fs):
+    return np.arange(int(time_range[0] * fs), 
+                           int(time_range[-1] * fs), 
+                           int(step * fs), dtype=np.int64)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--state-change-threshold',
+        '-scth',
+        dest='state_change_threshold',
+        help='Threshold for HMM state change',
+        default=0.75,
+        type=float
+    )
+    parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
+        '--model-filename',
+        dest='model_filename',
+        help='Model filename',
+        default=None,
+        type=str
+    )
+    return parser.parse_args()
+
+args = parse_args()
+# load model and fit hmm
+subj_name = args.subj
+model_filename = args.model_filename
+
+data_dir = f'./data/{subj_name}/'
+    
+model_path = f'./static/models/{subj_name}/{model_filename}'
+
+# load model
+model_type, _ = bci_utils.parse_model_type(model_filename)
+model = joblib.load(model_path)
+
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+    info = yaml.safe_load(f)
+sessions = info['hmm_sessions']
+
+raw = neo.raw_loader(data_dir, sessions, True)
+
+# cut into buffer len epochs
+if model_type == 'baseline':
+    feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
+elif model_type == 'riemann':
+    feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
+else:
+    raise ValueError
+
+# initiate hmm model
+hmm_model = HMMClassifier(model[-1], n_iter=100)
+hmm_model.fit(feature)
+
+# decode
+log_probs, state_seqs = hmm_model.decode(feature)
+plt.figure()
+plt.plot(state_seqs)
+
+# save transmat
+np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
+
+plt.show()

+ 1 - 1
backend/training.py

@@ -148,7 +148,7 @@ if __name__ == '__main__':
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw = neo.raw_loader(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 2 - 1
backend/validation.py

@@ -85,6 +85,7 @@ def _val_by_epochs_baseline(raw, events, model_path, duration):
 
 
 def _val_by_epochs_riemann(raw, events, model_path, duration):
+    fs = raw.info['sfreq']
     feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)
     filtered_data = feat_extractor.transform(raw.get_data())
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
@@ -113,7 +114,7 @@ if __name__ == '__main__':
     # preprocess raw
     trial_time = 5.
     upsampled_trial_duration = config_info['buffer_length']
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2],