Browse Source

训练范式适配模型api;自动载入训练好的transmat

dk 1 year ago
parent
commit
859de2e194

+ 13 - 0
backend/bci_core/online.py

@@ -2,6 +2,7 @@ import joblib
 import numpy as np
 import numpy as np
 import random
 import random
 import logging
 import logging
+import os
 from scipy import signal
 from scipy import signal
 from .utils import parse_model_type
 from .utils import parse_model_type
 
 
@@ -219,6 +220,18 @@ class RiemannHMM(HMMModel):
 
 
 
 
 def model_loader(model_path, **kwargs):
 def model_loader(model_path, **kwargs):
+    """
+    模型如果存在训练好的transmat,会直接load
+    """
+    model_root, model_filename = os.path.dirname(model_path), os.path.basename(model_path)
+    model_name = model_filename.split('.')[0]
+    transmat_path = os.path.join(model_root, model_name + '_transmat.txt')
+    if os.path.isfile(transmat_path):
+        transmat = np.loadtxt(transmat_path)
+    else:
+        transmat = None
+    kwargs['transmat'] = transmat
+
     model_type, _ = parse_model_type(model_path)
     model_type, _ = parse_model_type(model_path)
     if model_type == 'baseline':
     if model_type == 'baseline':
         return BaselineHMM(model_path, **kwargs)
         return BaselineHMM(model_path, **kwargs)

File diff suppressed because it is too large
+ 0 - 0
backend/free_grasp.psyexp


+ 10 - 4
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Dec  6 17:58:52 2023
+    on Tue Dec 12 13:24:05 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -43,7 +43,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 from settings.config import settings
 
 
 
 
@@ -93,9 +93,15 @@ def parse_args():
     return parser.parse_args()
     return parser.parse_args()
 args = parse_args()
 args = parse_args()
 
 
+# load model
+input_kwargs = {
+        'state_trans_prob': args.state_trans_prob,
+        'state_change_threshold': args.state_change_threshold
+    }
+control_model = model_loader(args.model_path, **input_kwargs)
+
 # build bci controller
 # build bci controller
-controller = Controller(0., args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+controller = Controller(0., control_model)
 # Run 'Before Experiment' code from device
 # Run 'Before Experiment' code from device
 # connect neo
 # connect neo
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 4 - 3
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Dec  6 17:50:50 2023
+    on Tue Dec 12 13:08:19 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -42,7 +42,7 @@ from device.data_client import NeuracleDataClient
 from device.trigger_box import TriggerNeuracle
 from device.trigger_box import TriggerNeuracle
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from device.fubo_pneumatic_finger import FuboPneumaticFingerClient
 from settings.config import settings
 from settings.config import settings
-from bci_core.online import Controller
+from bci_core.online import Controller, model_loader
 from settings.config import settings
 from settings.config import settings
 
 
 
 
@@ -123,8 +123,9 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 
 # build bci controller
 # build bci controller
+control_model = model_loader(args.model_path)
 controller = Controller(args.virtual_feedback_rate, 
 controller = Controller(args.virtual_feedback_rate, 
-                        args.model_path)
+                        control_model)
 # Run 'Before Experiment' code from decision
 # Run 'Before Experiment' code from decision
 cnt_threshold_table = {
 cnt_threshold_table = {
     'easy': 3,
     'easy': 3,

+ 1 - 4
backend/online_sim.py

@@ -187,9 +187,7 @@ if __name__ == '__main__':
 
 
     data_dir = f'./data/{subj_name}/'
     data_dir = f'./data/{subj_name}/'
     
     
-    model_filename = args.model_filename.split('.')[0]
-    model_path = f'./static/models/{subj_name}/{model_filename}.pkl'
-    transmat_path = f'./static/models/{subj_name}/{model_filename}_transmat.txt'
+    model_path = f'./static/models/{subj_name}/{args.model_filename}'
 
 
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
         info = yaml.safe_load(f)
@@ -209,7 +207,6 @@ if __name__ == '__main__':
     
     
     # load model
     # load model
     input_kwargs = {
     input_kwargs = {
-        'transmat': transmat_path if os.path.isfile(transmat_path) else None,
         'state_trans_prob': args.state_trans_prob,
         'state_trans_prob': args.state_trans_prob,
         'state_change_threshold': args.state_change_threshold
         'state_change_threshold': args.state_change_threshold
     }
     }

Some files were not shown because too many files changed in this diff