Browse Source

增加trans prob输入

dk 1 year ago
parent
commit
299f52e8e4

+ 8 - 7
.vscode/launch.json

@@ -14,12 +14,11 @@
             "justMyCode": true,
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "args": ["--subj", "XW01", 
             "--n-trials", "15", 
             "--n-trials", "15", 
-            // "--hand-feedback",
+            "--hand-feedback",
             "--com", "COM3", 
             "--com", "COM3", 
             "-fm", "flex", 
             "-fm", "flex", 
             "-vfr", "0.", 
             "-vfr", "0.", 
-            "-scth", "0.75",
-            "--difficulty", "hard",
+            "--difficulty", "mid",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         },
         {
         {
@@ -32,7 +31,8 @@
             "justMyCode": true,
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "args": ["--subj", "XW01", 
             "--com", "COM3", 
             "--com", "COM3", 
-            "-scth", "0.75",
+            "-scth", "0.9",
+            "-stp", "0.9",
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
             "--model-path", "./static/models/XW01/riemann_rest+flex_12-05-2023-19-10-25.pkl"]
         },
         },
         {
         {
@@ -67,7 +67,7 @@
             "cwd": "${workspaceFolder}/backend",
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "args": ["--subj", "XW01", 
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         },
         {
         {
             "name": "Online simulation",
             "name": "Online simulation",
@@ -78,8 +78,9 @@
             "cwd": "${workspaceFolder}/backend",
             "cwd": "${workspaceFolder}/backend",
             "justMyCode": true,
             "justMyCode": true,
             "args": ["--subj", "XW01", 
             "args": ["--subj", "XW01", 
-            "-scth", "0.75",
-            "--model-filename", "riemann_rest+flex_12-05-2023-19-10-25.pkl"]
+            "-scth", "0.9",
+            "-stp", "0.9",
+            "--model-filename", "riemann_rest+flex_12-06-2023-17-38-27.pkl"]
         },
         },
         {
         {
             "name": "Python: 当前文件",
             "name": "Python: 当前文件",

+ 7 - 2
backend/bci_core/online.py

@@ -22,15 +22,20 @@ class Controller:
     def __init__(self,
     def __init__(self,
                  virtual_feedback_rate=1., 
                  virtual_feedback_rate=1., 
                  model_path=None,
                  model_path=None,
+                 state_trans_prob=0.8,
                  state_change_threshold=0.6):
                  state_change_threshold=0.6):
         if (model_path is None) or (model_path == 'None'):
         if (model_path is None) or (model_path == 'None'):
             self.real_feedback_model = None
             self.real_feedback_model = None
         else:
         else:
             self.model_type, _ = parse_model_type(model_path)
             self.model_type, _ = parse_model_type(model_path)
             if self.model_type == 'baseline':
             if self.model_type == 'baseline':
-                self.real_feedback_model = BaselineHMM(model_path, state_change_threshold=state_change_threshold)
+                self.real_feedback_model = BaselineHMM(model_path, 
+                state_trans_prob=state_trans_prob,
+                state_change_threshold=state_change_threshold)
             elif self.model_type == 'riemann':
             elif self.model_type == 'riemann':
-                self.real_feedback_model = RiemannHMM(model_path, state_change_threshold=state_change_threshold)
+                self.real_feedback_model = RiemannHMM(model_path, 
+                state_trans_prob=state_trans_prob,
+                state_change_threshold=state_change_threshold)
             else:
             else:
                 raise NotImplementedError
                 raise NotImplementedError
         self.virtual_feedback_rate = virtual_feedback_rate
         self.virtual_feedback_rate = virtual_feedback_rate

File diff suppressed because it is too large
+ 0 - 1
backend/free_grasp.psyexp


+ 9 - 1
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Dec  6 17:54:49 2023
+    on Wed Dec  6 17:58:52 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -76,6 +76,14 @@ def parse_args():
         type=float
         type=float
     )
     )
     parser.add_argument(
     parser.add_argument(
+        '--state-trans-prob',
+        '-stp',
+        dest='state_trans_prob',
+        help='Transition probability for HMM state change',
+        default=0.8,
+        type=float
+    )
+    parser.add_argument(
         '--model-path',
         '--model-path',
         dest='model_path',
         dest='model_path',
         help='Path to model file',
         help='Path to model file',

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 3 - 11
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 十一月 29, 2023, at 12:36
+    on Wed Dec  6 17:50:50 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -93,13 +93,6 @@ def parse_args():
         type=float
         type=float
     )
     )
     parser.add_argument(
     parser.add_argument(
-        '--state-change-threshold',
-        '-scth',
-        dest='state_change_threshold',
-        help='Threshold for HMM state change',
-        type=float
-    )
-    parser.add_argument(
         '--difficulty',
         '--difficulty',
         help='Task difficultys',
         help='Task difficultys',
         type=str
         type=str
@@ -131,8 +124,7 @@ if args.hand_feedback:
 
 
 # build bci controller
 # build bci controller
 controller = Controller(args.virtual_feedback_rate, 
 controller = Controller(args.virtual_feedback_rate, 
-                        args.model_path, 
-                        state_change_threshold=args.state_change_threshold)
+                        args.model_path)
 # Run 'Before Experiment' code from decision
 # Run 'Before Experiment' code from decision
 cnt_threshold_table = {
 cnt_threshold_table = {
     'easy': 3,
     'easy': 3,
@@ -217,7 +209,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )
     )

Some files were not shown because too many files changed in this diff