Browse Source

Refactor: 统一使用settings中的路径设置,psychopy程序和离线脚本统一模型路径输入方式

dk 1 year ago
parent
commit
37e6c73f64

+ 2 - 2
.vscode/launch.json

@@ -19,7 +19,7 @@
             "-fm", "flex", 
             "-vfr", "0.", 
             "--difficulty", "mid",
-            "--model-path", "./static/models/XW01/riemann_rest+flex_01-02-2024-12-04-19.pkl"]
+            "--model-filename", "riemann_rest+flex_01-02-2024-12-04-19.pkl"]
         },
         {
             "name": "Grasp training",
@@ -46,7 +46,7 @@
             "-scth", "0.5",
             "-stp", "0.9",
             "--momentum", "0.7",
-            "--model-path", "./static/models/XW01/riemann_rest+flex_01-02-2024-12-04-19.pkl"]
+            "--model-filename", "riemann_rest+flex_01-02-2024-12-04-19.pkl"]
         },
         {
             "name": "Band selection",

+ 1 - 1
backend/band_selection.py

@@ -44,7 +44,7 @@ args = parse_args()
 
 subj_name = args.subj
 
-data_dir = f'./data/{subj_name}/'
+data_dir = os.path.join(settings.DATA_PATH, subj_name)
 
 with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
     info = yaml.safe_load(f)

File diff suppressed because it is too large
+ 0 - 0
backend/free_grasp.psyexp


+ 6 - 5
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Jan  3 16:56:09 2024
+    on Thu Jan  4 13:13:01 2024
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -90,9 +90,9 @@ def parse_args():
         type=float
     )
     parser.add_argument(
-        '--model-path',
-        dest='model_path',
-        help='Path to model file',
+        '--model-filename',
+        dest='model_filename',
+        help='Model file',
         default=None,
         type=str
     )
@@ -100,12 +100,13 @@ def parse_args():
 args = parse_args()
 
 # load model
+model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
 input_kwargs = {
         'state_trans_prob': args.state_trans_prob,
         'state_change_threshold': args.state_change_threshold,
         'momentum': args.momentum
     }
-control_model = model_loader(args.model_path, **input_kwargs)
+control_model = model_loader(model_path, **input_kwargs)
 
 # build bci controller
 controller = Controller(0., control_model, reref_method=config_info['reref'])

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 7 - 6
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 一月 02, 2024, at 17:07
+    on Thu Jan  4 13:11:53 2024
 If you publish work using this script the most relevant publication is:
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -98,9 +98,9 @@ def parse_args():
         type=str
     )
     parser.add_argument(
-        '--model-path',
-        dest='model_path',
-        help='Path to model file',
+        '--model-filename',
+        dest='model_filename',
+        help='Model file',
         default=None,
         type=str
     )
@@ -123,7 +123,8 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 # build bci controller
-control_model = model_loader(args.model_path)
+model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
+control_model = model_loader(model_path)
 controller = Controller(args.virtual_feedback_rate, 
                         control_model, reref_method=config_info['reref'])
 # Run 'Before Experiment' code from decision
@@ -210,7 +211,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )

+ 2 - 3
backend/online_sim.py

@@ -199,9 +199,8 @@ if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
 
-    data_dir = f'./data/{subj_name}/'
-    
-    model_path = f'./static/models/{subj_name}/{args.model_filename}'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
 
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)

+ 1 - 0
backend/settings/config.py

@@ -51,6 +51,7 @@ class Settings:
     }
     PROJECT_VERSION: str = '0.0.1'
     DATA_PATH = './data'
+    MODEL_PATH = './static/models'
 
 
 settings = Settings()

+ 2 - 3
backend/train_hmm.py

@@ -112,9 +112,8 @@ if __name__ == '__main__':
     subj_name = args.subj
     model_filename = args.model_filename
 
-    data_dir = f'./data/{subj_name}/'
-        
-    model_path = f'./static/models/{subj_name}/{model_filename}'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, model_filename)
 
     # load model
     model = joblib.load(model_path)

+ 2 - 2
backend/training.py

@@ -84,8 +84,8 @@ if __name__ == '__main__':
     subj_name = args.subj
     model_type = args.model_type
 
-    data_dir = f'./data/{subj_name}/'
-    model_dir = './static/models/'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_dir = settings.MODEL_PATH
 
     with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
         model_config = yaml.safe_load(f)[model_type]

+ 2 - 2
backend/validation.py

@@ -69,9 +69,9 @@ if __name__ == '__main__':
     args = parse_args()
     subj_name = args.subj
 
-    data_dir = f'./data/{subj_name}/'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
     
-    model_path = f'./static/models/{subj_name}/{args.model_filename}'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
     sessions = info['sessions']

Some files were not shown because too many files changed in this diff