Browse Source

Refactor: 统一使用settings中的路径设置,psychopy程序和离线脚本统一模型路径输入方式

dk 1 year ago
parent
commit
37e6c73f64

+ 2 - 2
.vscode/launch.json

@@ -19,7 +19,7 @@
             "-fm", "flex", 
             "-fm", "flex", 
             "-vfr", "0.", 
             "-vfr", "0.", 
             "--difficulty", "mid",
             "--difficulty", "mid",
-            "--model-path", "./static/models/XW01/riemann_rest+flex_01-02-2024-12-04-19.pkl"]
+            "--model-filename", "riemann_rest+flex_01-02-2024-12-04-19.pkl"]
         },
         },
         {
         {
             "name": "Grasp training",
             "name": "Grasp training",
@@ -46,7 +46,7 @@
             "-scth", "0.5",
             "-scth", "0.5",
             "-stp", "0.9",
             "-stp", "0.9",
             "--momentum", "0.7",
             "--momentum", "0.7",
-            "--model-path", "./static/models/XW01/riemann_rest+flex_01-02-2024-12-04-19.pkl"]
+            "--model-filename", "riemann_rest+flex_01-02-2024-12-04-19.pkl"]
         },
         },
         {
         {
             "name": "Band selection",
             "name": "Band selection",

+ 1 - 1
backend/band_selection.py

@@ -44,7 +44,7 @@ args = parse_args()
 
 
 subj_name = args.subj
 subj_name = args.subj
 
 
-data_dir = f'./data/{subj_name}/'
+data_dir = os.path.join(settings.DATA_PATH, subj_name)
 
 
 with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
 with open(os.path.join(data_dir, 'train_info.yml'), 'r') as f:
     info = yaml.safe_load(f)
     info = yaml.safe_load(f)

File diff suppressed because it is too large
+ 0 - 0
backend/free_grasp.psyexp


+ 6 - 5
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on Wed Jan  3 16:56:09 2024
+    on Thu Jan  4 13:13:01 2024
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -90,9 +90,9 @@ def parse_args():
         type=float
         type=float
     )
     )
     parser.add_argument(
     parser.add_argument(
-        '--model-path',
-        dest='model_path',
-        help='Path to model file',
+        '--model-filename',
+        dest='model_filename',
+        help='Model file',
         default=None,
         default=None,
         type=str
         type=str
     )
     )
@@ -100,12 +100,13 @@ def parse_args():
 args = parse_args()
 args = parse_args()
 
 
 # load model
 # load model
+model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
 input_kwargs = {
 input_kwargs = {
         'state_trans_prob': args.state_trans_prob,
         'state_trans_prob': args.state_trans_prob,
         'state_change_threshold': args.state_change_threshold,
         'state_change_threshold': args.state_change_threshold,
         'momentum': args.momentum
         'momentum': args.momentum
     }
     }
-control_model = model_loader(args.model_path, **input_kwargs)
+control_model = model_loader(model_path, **input_kwargs)
 
 
 # build bci controller
 # build bci controller
 controller = Controller(0., control_model, reref_method=config_info['reref'])
 controller = Controller(0., control_model, reref_method=config_info['reref'])

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 7 - 6
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 一月 02, 2024, at 17:07
+    on Thu Jan  4 13:11:53 2024
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -98,9 +98,9 @@ def parse_args():
         type=str
         type=str
     )
     )
     parser.add_argument(
     parser.add_argument(
-        '--model-path',
-        dest='model_path',
-        help='Path to model file',
+        '--model-filename',
+        dest='model_filename',
+        help='Model file',
         default=None,
         default=None,
         type=str
         type=str
     )
     )
@@ -123,7 +123,8 @@ if args.hand_feedback:
     hand_device = FuboPneumaticFingerClient({'port': args.com})
     hand_device = FuboPneumaticFingerClient({'port': args.com})
 
 
 # build bci controller
 # build bci controller
-control_model = model_loader(args.model_path)
+model_path = os.path.join(settings.MODEL_PATH, args.subj, args.model_filename)
+control_model = model_loader(model_path)
 controller = Controller(args.virtual_feedback_rate, 
 controller = Controller(args.virtual_feedback_rate, 
                         control_model, reref_method=config_info['reref'])
                         control_model, reref_method=config_info['reref'])
 # Run 'Before Experiment' code from decision
 # Run 'Before Experiment' code from decision
@@ -210,7 +211,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )
     )

+ 2 - 3
backend/online_sim.py

@@ -199,9 +199,8 @@ if __name__ == '__main__':
     args = parse_args()
     args = parse_args()
     subj_name = args.subj
     subj_name = args.subj
 
 
-    data_dir = f'./data/{subj_name}/'
-    
-    model_path = f'./static/models/{subj_name}/{args.model_filename}'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
 
 
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
         info = yaml.safe_load(f)

+ 1 - 0
backend/settings/config.py

@@ -51,6 +51,7 @@ class Settings:
     }
     }
     PROJECT_VERSION: str = '0.0.1'
     PROJECT_VERSION: str = '0.0.1'
     DATA_PATH = './data'
     DATA_PATH = './data'
+    MODEL_PATH = './static/models'
 
 
 
 
 settings = Settings()
 settings = Settings()

+ 2 - 3
backend/train_hmm.py

@@ -112,9 +112,8 @@ if __name__ == '__main__':
     subj_name = args.subj
     subj_name = args.subj
     model_filename = args.model_filename
     model_filename = args.model_filename
 
 
-    data_dir = f'./data/{subj_name}/'
-        
-    model_path = f'./static/models/{subj_name}/{model_filename}'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, model_filename)
 
 
     # load model
     # load model
     model = joblib.load(model_path)
     model = joblib.load(model_path)

+ 2 - 2
backend/training.py

@@ -84,8 +84,8 @@ if __name__ == '__main__':
     subj_name = args.subj
     subj_name = args.subj
     model_type = args.model_type
     model_type = args.model_type
 
 
-    data_dir = f'./data/{subj_name}/'
-    model_dir = './static/models/'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_dir = settings.MODEL_PATH
 
 
     with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
     with open(os.path.join(data_dir, 'model_config.yml'), 'r') as f:
         model_config = yaml.safe_load(f)[model_type]
         model_config = yaml.safe_load(f)[model_type]

+ 2 - 2
backend/validation.py

@@ -69,9 +69,9 @@ if __name__ == '__main__':
     args = parse_args()
     args = parse_args()
     subj_name = args.subj
     subj_name = args.subj
 
 
-    data_dir = f'./data/{subj_name}/'
+    data_dir = os.path.join(settings.DATA_PATH, subj_name)
+    model_path = os.path.join(settings.MODEL_PATH, subj_name, args.model_filename)
     
     
-    model_path = f'./static/models/{subj_name}/{args.model_filename}'
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
     with open(os.path.join(data_dir, 'val_info.yml'), 'r') as f:
         info = yaml.safe_load(f)
         info = yaml.safe_load(f)
     sessions = info['sessions']
     sessions = info['sessions']

Some files were not shown because too many files changed in this diff