Browse Source

训练范式适配模型api;自动载入训练好的transmat

dk 1 year ago
parent
commit
859de2e194

+ 13 - 0
backend/bci_core/online.py

@@ -2,6 +2,7 @@ import joblib
 import numpy as np
 import random
 import logging
+import os
 from scipy import signal
 from .utils import parse_model_type
 
@@ -219,6 +220,18 @@ class RiemannHMM(HMMModel):
 
 
 def model_loader(model_path, **kwargs):
+    """
+    模型如果存在训练好的transmat,会直接load
+    """
+    model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
+    model_name = model_filename.split('.')[0]
+    transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
+    if os.path.isfile(transmat_path):
+        transmat = np.loadtxt(transmat_path)
+    else:
+        transmat = None
+    kwargs['transmat'] = transmat
+
     model_type, _ = parse_model_type(model_path)
     if model_type == 'baseline':
         return BaselineHMM(model_path, **kwargs)

File diff suppressed because it is too large
+ 0 - 0
backend/free_grasp.psyexp


+ 10 - 4
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Dec  6 17:58:52 2023
+    on Tue Dec 12 13:24:05 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -43,7 +43,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -93,9 +93,15 @@ def parse_args():
     return parser.parse_args()
 args = parse_args()
 
+# load model
+input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+control_model = model_loader(args.model_path, **input_kwargs)
+
 # build bci controller
-controller = Controller(0., args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+controller = Controller(0., control_model)
 # Run 'Before Experiment' code from device
 # connect neo
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 4 - 3
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Dec  6 17:50:50 2023
+    on Tue Dec 12 13:08:19 2023
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -42,7 +42,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 
 
@@ -123,8 +123,9 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 # build bci controller
+control_model = model_loader(args.model_path)
 controller = Controller(args.virtual_feedback_rate, 
-                        args.model_path)
+                        control_model)
 # Run 'Before Experiment' code from decision
 cnt_threshold_table = {
     'easy': 3,

+ 1 - 4
backend/online_sim.py

@@ -187,9 +187,7 @@ if __name__ == '__main__':
 
     data_dir = f'./data/{subj_name}/'
     
-    model_filename = args.model_filename.split('.')[0]
-    model_path = f'./static/models/{subj_name}/{model_filename}.pkl'
-    transmat_path = f'./static/models/{subj_name}/{model_filename}_transmat.txt'
+    model_path = f'./static/models/{subj_name}/{args.model_filename}'
 
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
@@ -209,7 +207,6 @@ if __name__ == '__main__':
     
     # load model
     input_kwargs = {
-        'transmat': transmat_path if os.path.isfile(transmat_path) else None,
         'state_trans_prob': args.state_trans_prob,
         'state_change_threshold': args.state_change_threshold
     }

Some files were not shown because too many files changed in this diff