Ver código fonte

Merge branch 'hmm-thresh' of dk/kraken into master

dk 1 ano atrás
pai
commit
5271fe004e

+ 19 - 7
.vscode/launch.json

@@ -14,12 +14,11 @@
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "--n-trials", "15", 
-            // "--hand-feedback",
+            "--hand-feedback",
             "--com", "COM3", 
             "-fm", "flex", 
             "-vfr", "0.", 
-            "-scth", "0.75",
-            "--difficulty", "hard",
+            "--difficulty", "mid",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         {
@@ -32,7 +31,8 @@
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "--com", "COM3", 
-            "-scth", "0.75",
+            "-scth", "0.9",
+            "-stp", "0.9",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         {
@@ -67,7 +67,7 @@
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "args": ["--subj", "XW01", 
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         {
             "name": "Online simulation",
@@ -78,8 +78,20 @@
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "args": ["--subj", "XW01", 
-            "-scth", "0.75",
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "-scth", "0.9",
+            "-stp", "0.9",
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
+        },
+        {
+            "name": "Train hmm",
+            "type": "python",
+            "request": "launch",
+            "program": "train_hmm.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         {
             "name": "Python: 当前文件",

+ 1 - 1
backend/band_selection.py

@@ -55,7 +55,7 @@ for f in sessions.keys():
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+raw = neo.raw_loader(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 43 - 26
backend/bci_core/online.py

@@ -2,11 +2,11 @@ import joblib
 import numpy as np
 import random
 import logging
+import os
 from scipy import signal
-import mne
-from .feature_extractors import filterbank_extractor
 from .utils import parse_model_type
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -21,18 +21,9 @@ class Controller:
     """
     def __init__(self,
                  virtual_feedback_rate=1., 
-                 model_path=None,
-                 state_change_threshold=0.6):
-        if (model_path is None) or (model_path == 'None'):
-            self.real_feedback_model = None
-        else:
-            self.model_type, _ = parse_model_type(model_path)
-            if self.model_type == 'baseline':
-                self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
-            elif self.model_type == 'riemann':
-                self.real_feedback_model = RiemannHMM(model_path, state_change_threshold=state_change_threshold)
-            else:
-                raise NotImplementedError
+                 real_feedback_model=None):
+        
+        self.real_feedback_model = real_feedback_model
         self.virtual_feedback_rate = virtual_feedback_rate
 
     def step_decision(self, data, true_label=None):
@@ -108,25 +99,29 @@ class Controller:
 
 
 class HMMModel:
-    def __init__(self, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7):
+    def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
         self.n_classes = n_classes
-        self._probability = np.ones(n_classes) / n_classes
-        self._last_state = 0
+        self._probability = np.zeros(n_classes)
+        self.reset_state()
 
         self.state_change_threshold = state_change_threshold
 
-        # TODO: train with daily use data
-        # build state transition matrix
-        self.state_trans_matrix = np.zeros((n_classes, n_classes))
-        # fill diagonal
-        np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
-        # fill 0 -> each state, 
-        self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
-        self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        if transmat is None:
+            # build state transition matrix
+            self.state_trans_matrix = np.zeros((n_classes, n_classes))
+            # fill diagonal
+            np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
+            # fill 0 -> each state, 
+            self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
+            self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        else:
+            if isinstance(transmat, str):
+                transmat = np.loadtxt(transmat)
+            self.state_trans_matrix = transmat
 
     def reset_state(self):
+        self._probability[0] = 1.
         self._last_state = 0
-        self._probability = np.ones(self.n_classes) / self.n_classes
     
     def set_current_state(self, current_state):
         self._last_state = current_state
@@ -222,3 +217,25 @@ class RiemannHMM(HMMModel):
         # predict proba
         p = self.model.predict_proba(data).squeeze()
         return p
+
+
+def model_loader(model_path, **kwargs):
+    """
+    模型如果存在训练好的transmat,会直接load
+    """
+    model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
+    model_name = model_filename.split('.')[0]
+    transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
+    if os.path.isfile(transmat_path):
+        transmat = np.loadtxt(transmat_path)
+    else:
+        transmat = None
+    kwargs['transmat'] = transmat
+
+    model_type, _ = parse_model_type(model_path)
+    if model_type == 'baseline':
+        return BaselineHMM(model_path, **kwargs)
+    elif model_type == 'riemann':
+        return RiemannHMM(model_path, **kwargs)
+    else:
+        raise ValueError(f'Unexpected model type: {model_type}, expect "baseline" or "riemann"')

+ 16 - 8
backend/dataloaders/neo.py

@@ -14,7 +14,7 @@ FINGERMODEL_IDS = settings.FINGERMODEL_IDS
 CONFIG_INFO = settings.CONFIG_INFO
 
 
-def raw_preprocessing(data_root, session_paths:dict, 
+def raw_loader(data_root, session_paths:dict, 
                       do_rereference=True,
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
@@ -32,7 +32,7 @@ def raw_preprocessing(data_root, session_paths:dict,
         mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
         rest_trial_ind: only used when unify_label == True, 
     """
-    raws_loaded = load_sessions(data_root, session_paths)
+    raws_loaded = load_sessions(data_root, session_paths, do_rereference)
     # process event
     raws = []
     for (finger_model, raw) in raws_loaded:
@@ -63,15 +63,19 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
+    return raws
+
+
+def preprocessing(raw, do_rereference=True):
+    raw.load_data()
     if do_rereference:
         # common average
-        raws.set_eeg_reference('average')
+        raw.set_eeg_reference('average')
     # high pass
-    raws = raws.filter(1, None)
+    raw = raw.filter(1, None)
     # filter 50Hz
-    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
-
-    return raws
+    raw = raw.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
+    return raw
 
 
 def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
@@ -105,7 +109,7 @@ def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind
     return events_final
 
 
-def load_sessions(data_root, session_names: dict):
+def load_sessions(data_root, session_names: dict, do_rereference=True):
     # return raws for different finger models on an interleaved manner
     raw_cnt = sum(len(session_names[k]) for k in session_names)
     raws = []
@@ -124,6 +128,10 @@ def load_sessions(data_root, session_names: dict):
                 # kraken format
                 data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                 raw = mne.io.read_raw_bdf(data_file)
+            # preprocess raw
+            raw = preprocessing(raw, do_rereference)
+
+            # append list
             raws.append((finger_model, raw))
     return raws  
 

Diferenças do arquivo suprimidas por serem muito extensas
+ 2 - 3
backend/free_grasp.psyexp


+ 23 - 12
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Tue Nov 28 19:17:09 2023
+    on Tue Dec 12 13:24:05 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -43,7 +43,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -76,6 +76,14 @@ def parse_args():
         type=float
     )
     parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
         '--model-path',
         dest='model_path',
         help='Path to model file',
@@ -85,9 +93,15 @@ def parse_args():
     return parser.parse_args()
 args = parse_args()
 
+# load model
+input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+control_model = model_loader(args.model_path, **input_kwargs)
+
 # build bci controller
-controller = Controller(0., args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+controller = Controller(0., control_model)
 # Run 'Before Experiment' code from device
 # connect neo
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
@@ -600,7 +614,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
         
         # --- Run Routine "decision" ---
         routineForceEnded = not continueRoutine
-        while continueRoutine and routineTimer.getTime() < 0.1:
+        while continueRoutine:
             # get current time
             t = routineTimer.getTime()
             tThisFlip = win.getFutureFlipTime(clock=routineTimer)
@@ -631,7 +645,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
             # if feedback_bar is stopping this frame...
             if feedback_bar.status == STARTED:
                 # is it time to stop? (based on global clock, using actual start)
-                if tThisFlipGlobal > feedback_bar.tStartRefresh + 0.1-frameTolerance:
+                if tThisFlipGlobal > feedback_bar.tStartRefresh + config_info['buffer_length']-frameTolerance:
                     # keep track of stop time/frame for later
                     feedback_bar.tStop = t  # not accounting for scr refresh
                     feedback_bar.frameNStop = frameN  # exact frame index
@@ -667,11 +681,8 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
             if hasattr(thisComponent, "setAutoDraw"):
                 thisComponent.setAutoDraw(False)
         thisExp.addData('decision.stopped', globalClock.getTime())
-        # using non-slip timing so subtract the expected duration of this Routine (unless ended on request)
-        if routineForceEnded:
-            routineTimer.reset()
-        else:
-            routineTimer.addTime(-0.100000)
+        # the Routine "decision" was not non-slip safe, so reset the non-slip timer
+        routineTimer.reset()
         
         # --- Prepare to start Routine "feedback" ---
         continueRoutine = True
@@ -681,7 +692,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
         # state changed
         feedback_bar1.progress = force
         if decision != -1:
-            feedback_time = 5
+            feedback_time = 3
             if not decision:
                 trigger.send_trigger(0)
                 hand_device.extend()

Diferenças do arquivo suprimidas por serem muito extensas
+ 0 - 0
backend/general_grasp_training.psyexp


+ 5 - 12
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 十一月 29, 2023, at 12:36
+    on Tue Dec 12 13:08:19 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -42,7 +42,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -93,13 +93,6 @@ def parse_args():
         type=float
     )
     parser.add_argument(
-        '--state-change-threshold',
-        '-scth',
-        dest='state_change_threshold',
-        help='Threshold for HMM state change',
-        type=float
-    )
-    parser.add_argument(
         '--difficulty',
         help='Task difficultys',
         type=str
@@ -130,9 +123,9 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 # build bci controller
+control_model = model_loader(args.model_path)
 controller = Controller(args.virtual_feedback_rate, 
-                        args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+                        control_model)
 # Run 'Before Experiment' code from decision
 cnt_threshold_table = {
     'easy': 3,
@@ -217,7 +210,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )

+ 29 - 16
backend/online_sim.py

@@ -42,6 +42,14 @@ def parse_args():
         type=float
     )
     parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
         '--model-filename',
         dest='model_filename',
         help='Model filename',
@@ -70,16 +78,16 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
+def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
     val_data = raw.get_data()
     fs = raw.info['sfreq']
 
-    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
+    data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
 
     decision_with_hmm = []
     decision_without_hmm = []
     probs = []
-    for time, data in data_gen.loop():
+    for time, data in data_gen.loop(step_length):
         step_p, cls = model_hmm.viterbi(data, return_step_p=True)
         if cls >=0:
             cls = model_hmm.model.classes_[cls]
@@ -117,16 +125,16 @@ def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
 
 
 def simulation(raw_val, event_id, model, 
-               state_change_threshold=0.8, 
-               step_length=1., 
+               epoch_length=1., 
+               step_length=0.1,
                event_trial_length=5.):
     """模型验证接口,使用指定数据进行验证,绘制ersd map
     Args:
         raw (mne.io.Raw)
         event_id (dict)
         model: validate existing model, 
-        state_change_threshold (float): default 0.8
-        step_length (float): batch data step length, default 1. (s)
+        epoch_length (float): batch data length, default 1 (s)
+        step_length (float): data step length, default 0.1 (s)
         event_trial_length (float): 
 
     Returns:
@@ -142,16 +150,13 @@ def simulation(raw_val, event_id, model,
                                         rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
                                         mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
                                         use_original_label=True)
-    
-    
-    controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
-    model_hmm = controller.real_feedback_model
 
     # run with and without hmm
     fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
                                                           events_val, 
-                                                          model_hmm, 
-                                                          step_length, 
+                                                          model, 
+                                                          epoch_length, 
+                                                          step_length,
                                                           event_trial_length=event_trial_length)
 
     return metric_hmm, metric_naive, fig_pred
@@ -183,6 +188,7 @@ if __name__ == '__main__':
     data_dir = f'./data/{subj_name}/'
     
     model_path = f'./static/models/{subj_name}/{args.model_filename}'
+
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -192,18 +198,25 @@ if __name__ == '__main__':
     
     # preprocess raw
     trial_time = 5.
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2], 
                                 rest_trial_ind=[1], 
                                 upsampled_epoch_length=None)
+    
+    # load model
+    input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+    model_hmm = online.model_loader(model_path, **input_kwargs)
 
     # do validations
     metric_hmm, metric_naive, fig_pred = simulation(raw, 
                                              event_id, 
-                                             model=model_path, 
-                                             state_change_threshold=args.state_change_threshold,
+                                             model=model_hmm, 
+                                             epoch_length=config_info['buffer_length'],
                                              step_length=config_info['buffer_length'],
                                              event_trial_length=trial_time)
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   

+ 1 - 1
backend/tests/test_neoloader.py

@@ -10,7 +10,7 @@ class TestDataloader(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
         self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))

+ 6 - 4
backend/tests/test_online.py

@@ -33,7 +33,8 @@ class TestOnline(unittest.TestCase):
         return super().tearDownClass()
     
     def test_step_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
@@ -61,7 +62,8 @@ class TestOnline(unittest.TestCase):
         self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
 
     def test_real_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)
@@ -74,7 +76,7 @@ class TestHMM(unittest.TestCase):
         # binary
         probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
         true_state = [-1, -1, 1, -1, -1, -1, 0]
-        model = online.HMMModel(2, state_trans_prob=0.9, state_change_threshold=0.7)
+        model = online.HMMModel(transmat=None, n_classes=2, state_trans_prob=0.9, state_change_threshold=0.7)
         states = []
         for p in probs:
             cur_state = model.update_state(p)
@@ -84,7 +86,7 @@ class TestHMM(unittest.TestCase):
         # triple
         probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
         true_state = [-1, 1, -1, -1, 0, 2]
-        model = online.HMMModel(3, state_trans_prob=0.9)
+        model = online.HMMModel(transmat=None, n_classes=3, state_trans_prob=0.9)
         states = []
         for p in probs:
             cur_state = model.update_state(p)

+ 1 - 1
backend/tests/test_training.py

@@ -19,7 +19,7 @@ class TestTraining(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
         cls.raw = raw
     

+ 6 - 1
backend/tests/test_validation.py

@@ -5,6 +5,7 @@ from glob import glob
 import shutil
 
 from bci_core import utils as ana_utils
+from bci_core.online import model_loader
 from training import train_model, model_saver
 from dataloaders import library_ieeg
 from online_sim import simulation
@@ -51,9 +52,13 @@ class TestOnlineSim(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_sim(self):
-        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
+        model = model_loader(self.model_path, 
+                             state_change_threshold=0.7,
+                             state_trans_prob=0.7)
+        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=model, epoch_length=1., step_length=0.1)
         fig_pred.savefig('./tests/data/pred.pdf')   
 
+        print(metric_hmm)
         self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
         self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
     

+ 161 - 0
backend/train_hmm.py

@@ -0,0 +1,161 @@
+'''
+Use trained classifier as emission model, train HMM transfer matrix on free grasping tasks
+'''
+import os
+import argparse
+
+from hmmlearn import hmm
+import numpy as np
+import yaml
+import joblib
+from scipy import signal
+import matplotlib.pyplot as plt
+
+from dataloaders import neo
+import bci_core.utils as bci_utils
+from settings.config import settings
+
+
+config_info = settings.CONFIG_INFO
+
+class HMMClassifier(hmm.BaseHMM):
+    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
+    def __init__(self, emission_model, **kwargs):
+        n_components = len(emission_model.classes_)
+        super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
+
+        self.emission_model = emission_model
+    
+    def _check_and_set_n_features(self, X):
+        if X.ndim == 2:  # 
+            n_features = X.shape[1]
+        elif X.ndim == 3:
+            n_features = X.shape[1] * X.shape[2]
+        else:
+            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
+        if hasattr(self, "n_features"):
+            if self.n_features != n_features:
+                raise ValueError(
+                    f"Unexpected number of dimensions, got {n_features} but "
+                    f"expected {self.n_features}")
+        else:
+            self.n_features = n_features
+    
+    def _get_n_fit_scalars_per_param(self):
+        nc = self.n_components
+        return {
+            "s": nc,
+            "t": nc ** 2}
+        
+    def _compute_likelihood(self, X):
+        p = self.emission_model.predict_proba(X)
+        return p
+
+
+def extract_baseline_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, _ = model
+    filter_bank_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
+    # decimate
+    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    # to 10Hz
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    return filter_bank_epoch
+
+
+def extract_riemann_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, scaler, cov_model, _ = model
+    filtered_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
+    X = scaler.transform(X)
+    X_cov = cov_model.transform(X)
+    return X_cov
+
+
+def _split_continuous(time_range, step, fs):
+    return np.arange(int(time_range[0] * fs), 
+                           int(time_range[-1] * fs), 
+                           int(step * fs), dtype=np.int64)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--state-change-threshold',
+        '-scth',
+        dest='state_change_threshold',
+        help='Threshold for HMM state change',
+        default=0.75,
+        type=float
+    )
+    parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
+        '--model-filename',
+        dest='model_filename',
+        help='Model filename',
+        default=None,
+        type=str
+    )
+    return parser.parse_args()
+
+args = parse_args()
+# load model and fit hmm
+subj_name = args.subj
+model_filename = args.model_filename
+
+data_dir = f'./data/{subj_name}/'
+    
+model_path = f'./static/models/{subj_name}/{model_filename}'
+
+# load model
+model_type, _ = bci_utils.parse_model_type(model_filename)
+model = joblib.load(model_path)
+
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+    info = yaml.safe_load(f)
+sessions = info['hmm_sessions']
+
+raw = neo.raw_loader(data_dir, sessions, True)
+
+# cut into buffer len epochs
+if model_type == 'baseline':
+    feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
+elif model_type == 'riemann':
+    feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
+else:
+    raise ValueError
+
+# initiate hmm model
+hmm_model = HMMClassifier(model[-1], n_iter=100)
+hmm_model.fit(feature)
+
+# decode
+log_probs, state_seqs = hmm_model.decode(feature)
+plt.figure()
+plt.plot(state_seqs)
+
+# save transmat
+np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
+
+plt.show()

+ 1 - 1
backend/training.py

@@ -148,7 +148,7 @@ if __name__ == '__main__':
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw = neo.raw_loader(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 2 - 1
backend/validation.py

@@ -85,6 +85,7 @@ def _val_by_epochs_baseline(raw, events, model_path, duration):
 
 
 def _val_by_epochs_riemann(raw, events, model_path, duration):
+    fs = raw.info['sfreq']
     feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)
     filtered_data = feat_extractor.transform(raw.get_data())
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
@@ -113,7 +114,7 @@ if __name__ == '__main__':
     # preprocess raw
     trial_time = 5.
     upsampled_trial_duration = config_info['buffer_length']
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2], 

+ 1 - 0
environment.yml

@@ -18,3 +18,4 @@ dependencies:
       - scikit-learn~=1.3.2
       - matplotlib~=3.8.1
       - pyyaml~=6.0.1
+      - hmmlearn~=0.3.0

+ 2 - 1
requirements.txt

@@ -10,4 +10,5 @@ pyriemann==0.5
 scikit_learn==1.3.2
 scipy==1.11.3
 opencv_python~=4.8.1
-pyyaml~=6.0.1
+pyyaml~=6.0.1
+hmmlearn~=0.3.0

Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff