
Merge branch 'hmm-thresh' of dk/kraken into master

dk · 1 year ago
commit 5271fe004e

+ 19 - 7
.vscode/launch.json

@@ -14,12 +14,11 @@
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "--n-trials", "15", 
-            // "--hand-feedback",
+            "--hand-feedback",
             "--com", "COM3", 
             "-fm", "flex", 
             "-vfr", "0.", 
-            "-scth", "0.75",
-            "--difficulty", "hard",
+            "--difficulty", "mid",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         {
@@ -32,7 +31,8 @@
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "--com", "COM3", 
-            "-scth", "0.75",
+            "-scth", "0.9",
+            "-stp", "0.9",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         {
@@ -67,7 +67,7 @@
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "args": ["--subj", "XW01", 
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         {
             "name": "Online simulation",
@@ -78,8 +78,20 @@
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "args": ["--subj", "XW01", 
-            "-scth", "0.75",
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "-scth", "0.9",
+            "-stp", "0.9",
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
+        },
+        {
+            "name": "Train hmm",
+            "type": "python",
+            "request": "launch",
+            "program": "train_hmm.py",
+            "console": "integratedTerminal",
+            "cwd": "${workspaceFolder}/backend",
+            "justMyCode": true,
+            "args": ["--subj", "XW01", 
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         {
             "name": "Python: 当前文件",

+ 1 - 1
backend/band_selection.py

@@ -55,7 +55,7 @@ for f in sessions.keys():
 
 trial_duration = config_info['buffer_length']
 # preprocess raw
-raw = neo.raw_preprocessing(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+raw = neo.raw_loader(data_dir, sessions, do_rereference=False, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
 ###############################################################################
 # Pipeline with a frequency band selection based on the class distinctiveness

+ 43 - 26
backend/bci_core/online.py

@@ -2,11 +2,11 @@ import joblib
 import numpy as np
 import random
 import logging
+import os
 from scipy import signal
-import mne
-from .feature_extractors import filterbank_extractor
 from .utils import parse_model_type
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -21,18 +21,9 @@ class Controller:
     """
     def __init__(self,
                  virtual_feedback_rate=1., 
-                 model_path=None,
-                 state_change_threshold=0.6):
-        if (model_path is None) or (model_path == 'None'):
-            self.real_feedback_model = None
-        else:
-            self.model_type, _ = parse_model_type(model_path)
-            if self.model_type == 'baseline':
-                self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
-            elif self.model_type == 'riemann':
-                self.real_feedback_model = RiemannHMM(model_path, state_change_threshold=state_change_threshold)
-            else:
-                raise NotImplementedError
+                 real_feedback_model=None):
+        
+        self.real_feedback_model = real_feedback_model
         self.virtual_feedback_rate = virtual_feedback_rate
 
     def step_decision(self, data, true_label=None):
@@ -108,25 +99,29 @@ class Controller:
 
 
 class HMMModel:
-    def __init__(self, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.7):
+    def __init__(self, transmat=None, n_classes=2, state_trans_prob=0.6, state_change_threshold=0.5):
         self.n_classes = n_classes
-        self._probability = np.ones(n_classes) / n_classes
-        self._last_state = 0
+        self._probability = np.zeros(n_classes)
+        self.reset_state()
 
         self.state_change_threshold = state_change_threshold
 
-        # TODO: train with daily use data
-        # build state transition matrix
-        self.state_trans_matrix = np.zeros((n_classes, n_classes))
-        # fill diagonal
-        np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
-        # fill 0 -> each state, 
-        self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
-        self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        if transmat is None:
+            # build state transition matrix
+            self.state_trans_matrix = np.zeros((n_classes, n_classes))
+            # fill diagonal
+            np.fill_diagonal(self.state_trans_matrix, state_trans_prob)
+            # fill 0 -> each state, 
+            self.state_trans_matrix[0, 1:] = (1 - state_trans_prob) / (n_classes - 1)
+            self.state_trans_matrix[1:, 0] = 1 - state_trans_prob
+        else:
+            if isinstance(transmat, str):
+                transmat = np.loadtxt(transmat)
+            self.state_trans_matrix = transmat
 
     def reset_state(self):
+        self._probability[0] = 1.
         self._last_state = 0
-        self._probability = np.ones(self.n_classes) / self.n_classes
     
     def set_current_state(self, current_state):
         self._last_state = current_state
@@ -222,3 +217,25 @@ class RiemannHMM(HMMModel):
         # predict proba
         p = self.model.predict_proba(data).squeeze()
         return p
+
+
+def model_loader(model_path, **kwargs):
+    """
+    If a trained transmat file exists for this model, it is loaded directly.
+    """
+    model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
+    model_name = model_filename.split('.')[0]
+    transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
+    if os.path.isfile(transmat_path):
+        transmat = np.loadtxt(transmat_path)
+    else:
+        transmat = None
+    kwargs['transmat'] = transmat
+
+    model_type, _ = parse_model_type(model_path)
+    if model_type == 'baseline':
+        return BaselineHMM(model_path, **kwargs)
+    elif model_type == 'riemann':
+        return RiemannHMM(model_path, **kwargs)
+    else:
+        raise ValueError(f'Unexpected model type: {model_type}, expect "baseline" or "riemann"')
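
Note: this refactor moves model construction out of Controller and into model_loader. A minimal sketch of the new wiring, with illustrative parameter values (the model path mirrors the launch configs above; any trained model file works the same way):

from bci_core.online import Controller, model_loader

model_path = './static/models/XW01/riemann_rest+flex_12-06-2023-17-38-27.pkl'

# model_loader looks for '<model_name>_transmat.txt' next to the .pkl; if it exists,
# the learned transition matrix is used, otherwise HMMModel builds one from state_trans_prob.
real_feedback_model = model_loader(model_path,
                                   state_trans_prob=0.9,
                                   state_change_threshold=0.9)

# Controller no longer receives a path; it takes the constructed model instance.
controller = Controller(0., real_feedback_model)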

+ 16 - 8
backend/dataloaders/neo.py

@@ -14,7 +14,7 @@ FINGERMODEL_IDS = settings.FINGERMODEL_IDS
 CONFIG_INFO = settings.CONFIG_INFO
 
 
-def raw_preprocessing(data_root, session_paths:dict, 
+def raw_loader(data_root, session_paths:dict, 
                       do_rereference=True,
                       upsampled_epoch_length=1., 
                       ori_epoch_length=5, 
@@ -32,7 +32,7 @@ def raw_preprocessing(data_root, session_paths:dict, 
         mov_trial_ind: only used when unify_label == True, suggesting the raw file's annotations didn't use unified labels (old pony format)
         rest_trial_ind: only used when unify_label == True, 
     """
-    raws_loaded = load_sessions(data_root, session_paths)
+    raws_loaded = load_sessions(data_root, session_paths, do_rereference)
     # process event
     raws = []
     for (finger_model, raw) in raws_loaded:
@@ -63,15 +63,19 @@ def raw_preprocessing(data_root, session_paths:dict, 
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
+    return raws
+
+
+def preprocessing(raw, do_rereference=True):
+    raw.load_data()
     if do_rereference:
         # common average
-        raws.set_eeg_reference('average')
+        raw.set_eeg_reference('average')
     # high pass
-    raws = raws.filter(1, None)
+    raw = raw.filter(1, None)
     # filter 50Hz
-    raws = raws.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
-
-    return raws
+    raw = raw.notch_filter([50, 100, 150], trans_bandwidth=3, verbose=False)
+    return raw
 
 
 def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind=[2, 3], rest_trial_ind=[4], use_original_label=False):
@@ -105,7 +109,7 @@ def reconstruct_events(events, fs, finger_model, trial_duration=5, mov_trial_ind
     return events_final
 
 
-def load_sessions(data_root, session_names: dict):
+def load_sessions(data_root, session_names: dict, do_rereference=True):
     # return raws for different finger models on an interleaved manner
     raw_cnt = sum(len(session_names[k]) for k in session_names)
     raws = []
@@ -124,6 +128,10 @@ def load_sessions(data_root, session_names: dict):
                 # kraken format
                 data_file = glob.glob(os.path.join(data_root, s, '*.bdf'))[0]
                 raw = mne.io.read_raw_bdf(data_file)
+            # preprocess raw
+            raw = preprocessing(raw, do_rereference)
+
+            # append list
             raws.append((finger_model, raw))
     return raws  
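
Note: loading and preprocessing are now split; load_sessions() re-references and filters each recording via preprocessing() before raw_loader() concatenates and epochs them. A minimal usage sketch under an assumed (hypothetical) session layout; the keyword arguments mirror the callers updated in this commit:

from dataloaders import neo

sessions = {'flex': ['eeg-data/1', 'eeg-data/2']}  # hypothetical session paths
raw = neo.raw_loader('./data/XW01/', sessions,
                     do_rereference=True,          # applied per session inside load_sessions
                     unify_label=True,
                     upsampled_epoch_length=1.,
                     ori_epoch_length=5,
                     mov_trial_ind=[2],
                     rest_trial_ind=[1])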
 
 

File diff suppressed because it is too large
+ 2 - 3
backend/free_grasp.psyexp


+ 23 - 12
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Tue Nov 28 19:17:09 2023
+    on Tue Dec 12 13:24:05 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -43,7 +43,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -76,6 +76,14 @@ def parse_args():
         type=float
     )
     parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
         '--model-path',
         dest='model_path',
         help='Path to model file',
@@ -85,9 +93,15 @@ def parse_args():
     return parser.parse_args()
 args = parse_args()
 
+# load model
+input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+control_model = model_loader(args.model_path, **input_kwargs)
+
 # build bci controller
-controller = Controller(0., args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+controller = Controller(0., control_model)
 # Run 'Before Experiment' code from device
 # connect neo
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
@@ -600,7 +614,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
         
         # --- Run Routine "decision" ---
         routineForceEnded = not continueRoutine
-        while continueRoutine and routineTimer.getTime() < 0.1:
+        while continueRoutine:
             # get current time
             t = routineTimer.getTime()
             tThisFlip = win.getFutureFlipTime(clock=routineTimer)
@@ -631,7 +645,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
             # if feedback_bar is stopping this frame...
             if feedback_bar.status == STARTED:
                 # is it time to stop? (based on global clock, using actual start)
-                if tThisFlipGlobal > feedback_bar.tStartRefresh + 0.1-frameTolerance:
+                if tThisFlipGlobal > feedback_bar.tStartRefresh + config_info['buffer_length']-frameTolerance:
                     # keep track of stop time/frame for later
                     feedback_bar.tStop = t  # not accounting for scr refresh
                     feedback_bar.frameNStop = frameN  # exact frame index
@@ -667,11 +681,8 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
             if hasattr(thisComponent, "setAutoDraw"):
                 thisComponent.setAutoDraw(False)
         thisExp.addData('decision.stopped', globalClock.getTime())
-        # using non-slip timing so subtract the expected duration of this Routine (unless ended on request)
-        if routineForceEnded:
-            routineTimer.reset()
-        else:
-            routineTimer.addTime(-0.100000)
+        # the Routine "decision" was not non-slip safe, so reset the non-slip timer
+        routineTimer.reset()
         
         # --- Prepare to start Routine "feedback" ---
         continueRoutine = True
@@ -681,7 +692,7 @@ def run(expInfo, thisExp, win, inputs, globalClock=None, thisSession=None):
         # state changed
         feedback_bar1.progress = force
         if decision != -1:
-            feedback_time = 5
+            feedback_time = 3
             if not decision:
                 trigger.send_trigger(0)
                 hand_device.extend()

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 5 - 12
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on November 29, 2023, at 12:36
+    on Tue Dec 12 13:08:19 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -42,7 +42,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -93,13 +93,6 @@ def parse_args():
         type=float
     )
     parser.add_argument(
-        '--state-change-threshold',
-        '-scth',
-        dest='state_change_threshold',
-        help='Threshold for HMM state change',
-        type=float
-    )
-    parser.add_argument(
         '--difficulty',
         help='Task difficultys',
         type=str
@@ -130,9 +123,9 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 # build bci controller
+control_model = model_loader(args.model_path)
 controller = Controller(args.virtual_feedback_rate, 
-                        args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+                        control_model)
 # Run 'Before Experiment' code from decision
 cnt_threshold_table = {
     'easy': 3,
@@ -217,7 +210,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )

+ 29 - 16
backend/online_sim.py

@@ -42,6 +42,14 @@ def parse_args():
         type=float
     )
     parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
         '--model-filename',
         dest='model_filename',
         help='Model filename',
@@ -70,16 +78,16 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
 
 
-def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
+def _evaluation_loop(raw, events, model_hmm, epoch_length, step_length, event_trial_length):
     val_data = raw.get_data()
     fs = raw.info['sfreq']
 
-    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
+    data_gen = DataGenerator(fs, val_data, epoch_step=epoch_length)
 
     decision_with_hmm = []
     decision_without_hmm = []
     probs = []
-    for time, data in data_gen.loop():
+    for time, data in data_gen.loop(step_length):
         step_p, cls = model_hmm.viterbi(data, return_step_p=True)
         if cls >=0:
             cls = model_hmm.model.classes_[cls]
@@ -117,16 +125,16 @@ def _evaluation_loop(raw, events, model_hmm, step_length, event_trial_length):
 
 
 def simulation(raw_val, event_id, model, 
-               state_change_threshold=0.8, 
-               step_length=1., 
+               epoch_length=1., 
+               step_length=0.1,
                event_trial_length=5.):
     """Model validation interface: validate on the given data and plot the ERSD map
     Args:
         raw (mne.io.Raw)
         event_id (dict)
         model: validate existing model, 
-        state_change_threshold (float): default 0.8
-        step_length (float): batch data step length, default 1. (s)
+        epoch_length (float): batch data length, default 1 (s)
+        step_length (float): data step length, default 0.1 (s)
         event_trial_length (float): 
 
     Returns:
@@ -142,16 +150,13 @@ def simulation(raw_val, event_id, model, 
                                         rest_trial_ind=[v for k, v in event_id.items() if k == 'rest'], 
                                         mov_trial_ind=[v for k, v in event_id.items() if k != 'rest'], 
                                         use_original_label=True)
-    
-    
-    controller = online.Controller(0, model, state_change_threshold=state_change_threshold)
-    model_hmm = controller.real_feedback_model
 
     # run with and without hmm
     fig_pred, metric_hmm, metric_naive = _evaluation_loop(raw_val, 
                                                           events_val, 
-                                                          model_hmm, 
-                                                          step_length, 
+                                                          model, 
+                                                          epoch_length, 
+                                                          step_length,
                                                           event_trial_length=event_trial_length)
 
     return metric_hmm, metric_naive, fig_pred
@@ -183,6 +188,7 @@ if __name__ == '__main__':
     data_dir = f'./data/{subj_name}/'
     
     model_path = f'./static/models/{subj_name}/{args.model_filename}'
+
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']
@@ -192,18 +198,25 @@ if __name__ == '__main__':
     
     # preprocess raw
     trial_time = 5.
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2], 
                                 rest_trial_ind=[1], 
                                 upsampled_epoch_length=None)
+    
+    # load model
+    input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+    model_hmm = online.model_loader(model_path, **input_kwargs)
 
     # do validations
     metric_hmm, metric_naive, fig_pred = simulation(raw, 
                                              event_id, 
-                                             model=model_path, 
-                                             state_change_threshold=args.state_change_threshold,
+                                             model=model_hmm, 
+                                             epoch_length=config_info['buffer_length'],
                                              step_length=config_info['buffer_length'],
                                              event_trial_length=trial_time)
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   

+ 1 - 1
backend/tests/test_neoloader.py

@@ -10,7 +10,7 @@ class TestDataloader(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         events, event_id = mne.events_from_annotations(raw, event_id=event_id)
         events, events_cnt = np.unique(events[:, -1], return_counts=True)
         self.assertTrue(np.allclose(events_cnt, (300, 150, 150)))

+ 6 - 4
backend/tests/test_online.py

@@ -33,7 +33,8 @@ class TestOnline(unittest.TestCase):
         return super().tearDownClass()
     
     def test_step_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for time, data in self.data_gen.loop():
             cls = controller.step_decision(data)
@@ -61,7 +62,8 @@ class TestOnline(unittest.TestCase):
         self.assertTrue(abs(correct / n_trial - 0.8) < 0.1)
 
     def test_real_feedback(self):
-        controller = online.Controller(0, self.model_path)
+        model_hmm = online.model_loader(self.model_path)
+        controller = online.Controller(0, model_hmm)
         rets = []
         for i, (time, data) in zip(range(300), self.data_gen.loop()):
             cls = controller.decision(data)
@@ -74,7 +76,7 @@ class TestHMM(unittest.TestCase):
         # binary
         probs = [[0.9, 0.1], [0.5, 0.5], [0.09, 0.91], [0.5, 0.5], [0.3, 0.7], [0.7, 0.3], [0.92,0.08]]
         true_state = [-1, -1, 1, -1, -1, -1, 0]
-        model = online.HMMModel(2, state_trans_prob=0.9, state_change_threshold=0.7)
+        model = online.HMMModel(transmat=None, n_classes=2, state_trans_prob=0.9, state_change_threshold=0.7)
         states = []
         for p in probs:
             cur_state = model.update_state(p)
@@ -84,7 +86,7 @@ class TestHMM(unittest.TestCase):
         # triple
         probs = [[0.8, 0.1, 0.1], [0.01, 0.91, 0.09], [0.01, 0.08, 0.91], [0.5, 0.2, 0.3], [0.9, 0.05, 0.02], [0.01, 0.01, 0.98]]
         true_state = [-1, 1, -1, -1, 0, 2]
-        model = online.HMMModel(3, state_trans_prob=0.9)
+        model = online.HMMModel(transmat=None, n_classes=3, state_trans_prob=0.9)
         states = []
         for p in probs:
             cur_state = model.update_state(p)

+ 1 - 1
backend/tests/test_training.py

@@ -19,7 +19,7 @@ class TestTraining(unittest.TestCase):
         sessions = {'cylinder': ['eeg-data/2', 'eeg-data/6'], 'ball': ['eeg-data/4', 'eeg-data/5']}
         cls.event_id = {'rest': 0, 'cylinder': 1, 'ball': 2}
 
-        raw = neo.raw_preprocessing(root_path, sessions, unify_label=True)
+        raw = neo.raw_loader(root_path, sessions, unify_label=True)
         raw.drop_channels(['T3', 'T4', 'A1', 'A2', 'T5', 'T6', 'M1', 'M2', 'Fp1', 'Fp2', 'F7', 'F8', 'O1', 'Oz', 'O2', 'F3', 'F4', 'Fz'])
         cls.raw = raw
     

+ 6 - 1
backend/tests/test_validation.py

@@ -5,6 +5,7 @@ from glob import glob
 import shutil
 
 from bci_core import utils as ana_utils
+from bci_core.online import model_loader
 from training import train_model, model_saver
 from dataloaders import library_ieeg
 from online_sim import simulation
@@ -51,9 +52,13 @@ class TestOnlineSim(unittest.TestCase):
         self.assertEqual(recall, 1)
 
     def test_sim(self):
-        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
+        model = model_loader(self.model_path, 
+                             state_change_threshold=0.7,
+                             state_trans_prob=0.7)
+        metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=model, epoch_length=1., step_length=0.1)
         fig_pred.savefig('./tests/data/pred.pdf')   
 
+        print(metric_hmm)
         self.assertTrue(metric_hmm[-2] > 0.9)  # f1-score (with hmm)
         self.assertTrue(metric_nohmm[-2] < 0.5)  # f1-score (without hmm)
     

+ 161 - 0
backend/train_hmm.py

@@ -0,0 +1,161 @@
+'''
+Use trained classifier as emission model, train HMM transition matrix on free grasping tasks
+'''
+import os
+import argparse
+
+from hmmlearn import hmm
+import numpy as np
+import yaml
+import joblib
+from scipy import signal
+import matplotlib.pyplot as plt
+
+from dataloaders import neo
+import bci_core.utils as bci_utils
+from settings.config import settings
+
+
+config_info = settings.CONFIG_INFO
+
+class HMMClassifier(hmm.BaseHMM):
+    # TODO: how to bypass sklearn.check_array, currently I modified the src of hmmlearn (remove all the check_array)
+    def __init__(self, emission_model, **kwargs):
+        n_components = len(emission_model.classes_)
+        super(HMMClassifier, self).__init__(n_components=n_components, params='t', init_params='st', **kwargs)
+
+        self.emission_model = emission_model
+    
+    def _check_and_set_n_features(self, X):
+        if X.ndim == 2:  # 
+            n_features = X.shape[1]
+        elif X.ndim == 3:
+            n_features = X.shape[1] * X.shape[2]
+        else:
+            raise ValueError(f'Unexpected data dimension, got {X.ndim} but expected 2 or 3')
+        if hasattr(self, "n_features"):
+            if self.n_features != n_features:
+                raise ValueError(
+                    f"Unexpected number of dimensions, got {n_features} but "
+                    f"expected {self.n_features}")
+        else:
+            self.n_features = n_features
+    
+    def _get_n_fit_scalars_per_param(self):
+        nc = self.n_components
+        return {
+            "s": nc,
+            "t": nc ** 2}
+        
+    def _compute_likelihood(self, X):
+        p = self.emission_model.predict_proba(X)
+        return p
+
+
+def extract_baseline_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, _ = model
+    filter_bank_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    filter_bank_epoch = bci_utils.cut_epochs((0, step, fs), filter_bank_data, timestamps)
+    # decimate
+    decimate_rate = np.sqrt(fs / 10).astype(np.int16)
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    # to 10Hz
+    filter_bank_epoch = signal.decimate(filter_bank_epoch, decimate_rate, axis=-1, zero_phase=True)
+    return filter_bank_epoch
+
+
+def extract_riemann_feature(model, raw, step):
+    fs = raw.info['sfreq']
+    feat_extractor, scaler, cov_model, _ = model
+    filtered_data = feat_extractor.transform(raw.get_data())
+    timestamps = _split_continuous((raw.times[0], raw.times[-1]), step, fs)
+    X = bci_utils.cut_epochs((0, step, fs), filtered_data, timestamps)
+    X = scaler.transform(X)
+    X_cov = cov_model.transform(X)
+    return X_cov
+
+
+def _split_continuous(time_range, step, fs):
+    return np.arange(int(time_range[0] * fs), 
+                           int(time_range[-1] * fs), 
+                           int(step * fs), dtype=np.int64)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Model validation'
+    )
+    parser.add_argument(
+        '--subj',
+        dest='subj',
+        help='Subject name',
+        default=None,
+        type=str
+    )
+    parser.add_argument(
+        '--state-change-threshold',
+        '-scth',
+        dest='state_change_threshold',
+        help='Threshold for HMM state change',
+        default=0.75,
+        type=float
+    )
+    parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
+        '--model-filename',
+        dest='model_filename',
+        help='Model filename',
+        default=None,
+        type=str
+    )
+    return parser.parse_args()
+
+args = parse_args()
+# load model and fit hmm
+subj_name = args.subj
+model_filename = args.model_filename
+
+data_dir = f'./data/{subj_name}/'
+    
+model_path = f'./static/models/{subj_name}/{model_filename}'
+
+# load model
+model_type, _ = bci_utils.parse_model_type(model_filename)
+model = joblib.load(model_path)
+
+with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
+    info = yaml.safe_load(f)
+sessions = info['hmm_sessions']
+
+raw = neo.raw_loader(data_dir, sessions, True)
+
+# cut into buffer len epochs
+if model_type == 'baseline':
+    feature = extract_baseline_feature(model, raw, config_info['buffer_length'])
+elif model_type == 'riemann':
+    feature = extract_riemann_feature(model, raw, config_info['buffer_length'])
+else:
+    raise ValueError
+
+# initiate hmm model
+hmm_model = HMMClassifier(model[-1], n_iter=100)
+hmm_model.fit(feature)
+
+# decode
+log_probs, state_seqs = hmm_model.decode(feature)
+plt.figure()
+plt.plot(state_seqs)
+
+# save transmat
+np.savetxt(f'./static/models/{subj_name}/{model_filename.split(".")[0]}_transmat.txt', hmm_model.transmat_)
+
+plt.show()
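
Note: the script's output follows the '<model_name>_transmat.txt' convention that bci_core.online.model_loader() checks at load time. A small sketch of that round-trip (the path is illustrative):

import os
import numpy as np

model_path = './static/models/XW01/riemann_rest+flex_12-06-2023-17-38-27.pkl'
model_name = os.path.basename(model_path).split('.')[0]
transmat_path = os.path.join(os.path.dirname(model_path), model_name + '_transmat.txt')

# train_hmm.py writes the fitted transition matrix here with np.savetxt; model_loader()
# reads it back with np.loadtxt if the file exists, otherwise it falls back to the
# default matrix built from state_trans_prob.
transmat = np.loadtxt(transmat_path) if os.path.isfile(transmat_path) else None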

+ 1 - 1
backend/training.py

@@ -148,7 +148,7 @@ if __name__ == '__main__':
     
     trial_duration = config_info['buffer_length']
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw = neo.raw_loader(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
     # train model
     model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration, **model_config)

+ 2 - 1
backend/validation.py

@@ -85,6 +85,7 @@ def _val_by_epochs_baseline(raw, events, model_path, duration):
 
 
 def _val_by_epochs_riemann(raw, events, model_path, duration):
+    fs = raw.info['sfreq']
     feat_extractor, scaler, cov_model, riemann_model = joblib.load(model_path)
     filtered_data = feat_extractor.transform(raw.get_data())
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
@@ -113,7 +114,7 @@ if __name__ == '__main__':
     # preprocess raw
     trial_time = 5.
     upsampled_trial_duration = config_info['buffer_length']
-    raw = neo.raw_preprocessing(data_dir, sessions, 
+    raw = neo.raw_loader(data_dir, sessions, 
                                 unify_label=True, 
                                 ori_epoch_length=trial_time, 
                                 mov_trial_ind=[2], 

+ 1 - 0
environment.yml

@@ -18,3 +18,4 @@ dependencies:
       - scikit-learn~=1.3.2
       - matplotlib~=3.8.1
       - pyyaml~=6.0.1
+      - hmmlearn~=0.3.0

+ 2 - 1
requirements.txt

@@ -10,4 +10,5 @@ pyriemann==0.5
 scikit_learn==1.3.2
 scipy==1.11.3
 opencv_python~=4.8.1
-pyyaml~=6.0.1
+pyyaml~=6.0.1
+hmmlearn~=0.3.0

Some files were not shown because too many files changed in this diff