Browse Source

Merge branch 'riemann' of dk/kraken into master

dk 1 year ago
parent
commit
f1a7ae4870

+ 1 - 1
backend/bci_core/feature_extractors.py

@@ -84,7 +84,7 @@ class LFPExtractor:
         """
         """
         lfp_data = []
         lfp_data = []
         for b in self.lfb_bands:
         for b in self.lfb_bands:
-            band_data = filter.filter_data(data, self.sfreq, b[0], b[1], method='iir', phase='zero')
+            band_data = filter.filter_data(data, self.sfreq, b[0], b[1], method='iir', phase='zero', verbose=False)
             lfp_data.append(band_data)
             lfp_data.append(band_data)
         lfp_data = np.concatenate(lfp_data, axis=0)
         lfp_data = np.concatenate(lfp_data, axis=0)
         return lfp_data
         return lfp_data

+ 1 - 1
backend/device/data_client.py

@@ -9,7 +9,7 @@ class NeuracleDataClient:
     UPDATE_INTERVAL = 0.04
     UPDATE_INTERVAL = 0.04
     BYTES_PER_NUM = 4
     BYTES_PER_NUM = 4
 
 
-    def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712, buffer_len=1):
+    def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712, buffer_len=1.):
         self.n_channel = n_channel
         self.n_channel = n_channel
         self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         self.chunk_size = int(self.UPDATE_INTERVAL * samplerate * self.BYTES_PER_NUM * n_channel)
         self.chunk_size = int(self.UPDATE_INTERVAL * samplerate * self.BYTES_PER_NUM * n_channel)

+ 1 - 1
backend/free_grasp.psyexp

@@ -269,7 +269,7 @@
         <Param val="parameter_inputs" valType="code" updates="None" name="name"/>
         <Param val="parameter_inputs" valType="code" updates="None" name="name"/>
       </CodeComponent>
       </CodeComponent>
       <CodeComponent name="device" plugin="None">
       <CodeComponent name="device" plugin="None">
-        <Param val="# connect neo&amp;#10;receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), &amp;#10;                               samplerate=config_info['sample_rate'],&amp;#10;                               host=config_info['host'],&amp;#10;                               port=config_info['port'])&amp;#10;&amp;#10;# connect to trigger box&amp;#10;trigger = TriggerNeuracle()&amp;#10;&amp;#10;# connect to mechanical hand&amp;#10;hand_device = FuboPneumaticFingerClient({'port': args.com})&amp;#10;" valType="extendedCode" updates="constant" name="Before Experiment"/>
+        <Param val="# connect neo&amp;#10;receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), &amp;#10;                               samplerate=config_info['sample_rate'],&amp;#10;                               host=config_info['host'],&amp;#10;                               port=config_info['port'].&amp;#10;                               buffer_len=config_info['buffer_length'])&amp;#10;&amp;#10;# connect to trigger box&amp;#10;trigger = TriggerNeuracle()&amp;#10;&amp;#10;# connect to mechanical hand&amp;#10;hand_device = FuboPneumaticFingerClient({'port': args.com})&amp;#10;" valType="extendedCode" updates="constant" name="Before Experiment"/>
         <Param val="receiver = new NeuracleDataClient({&quot;n_channel&quot;: config_info[&quot;channel_labels&quot;].length, &quot;samplerate&quot;: config_info[&quot;sample_rate&quot;], &quot;host&quot;: config_info[&quot;host&quot;], &quot;port&quot;: config_info[&quot;port&quot;]});&amp;#10;trigger = new TriggerNeuracle();&amp;#10;hand_device = new FuboPneumaticFingerClient({&quot;port&quot;: args.com});&amp;#10;controller = new Controller(0.0, args.model_path, {&quot;state_change_threshold&quot;: 0.8});&amp;#10;" valType="extendedCode" updates="constant" name="Before JS Experiment"/>
         <Param val="receiver = new NeuracleDataClient({&quot;n_channel&quot;: config_info[&quot;channel_labels&quot;].length, &quot;samplerate&quot;: config_info[&quot;sample_rate&quot;], &quot;host&quot;: config_info[&quot;host&quot;], &quot;port&quot;: config_info[&quot;port&quot;]});&amp;#10;trigger = new TriggerNeuracle();&amp;#10;hand_device = new FuboPneumaticFingerClient({&quot;port&quot;: args.com});&amp;#10;controller = new Controller(0.0, args.model_path, {&quot;state_change_threshold&quot;: 0.8});&amp;#10;" valType="extendedCode" updates="constant" name="Before JS Experiment"/>
         <Param val="" valType="extendedCode" updates="constant" name="Begin Experiment"/>
         <Param val="" valType="extendedCode" updates="constant" name="Begin Experiment"/>
         <Param val="" valType="extendedCode" updates="constant" name="Begin JS Experiment"/>
         <Param val="" valType="extendedCode" updates="constant" name="Begin JS Experiment"/>

+ 4 - 3
backend/free_grasp.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 十一月 21, 2023, at 14:15
+    on Tue Nov 21 18:13:33 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -93,7 +93,8 @@ controller = Controller(0., args.model_path,
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
                                samplerate=config_info['sample_rate'],
                                samplerate=config_info['sample_rate'],
                                host=config_info['host'],
                                host=config_info['host'],
-                               port=config_info['port'])
+                               port=config_info['port'].
+                               buffer_len=config_info['buffer_length'])
 
 
 # connect to trigger box
 # connect to trigger box
 trigger = TriggerNeuracle()
 trigger = TriggerNeuracle()
@@ -175,7 +176,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\free_grasp.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/free_grasp.py',
         savePickle=True, saveWideText=True,
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )
     )

File diff suppressed because it is too large
+ 0 - 0
backend/general_grasp_training.psyexp


+ 4 - 3
backend/general_grasp_training.py

@@ -2,7 +2,7 @@
 # -*- coding: utf-8 -*-
 # -*- coding: utf-8 -*-
 """
 """
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
 This experiment was created using PsychoPy3 Experiment Builder (v2023.2.3),
-    on 十一月 21, 2023, at 13:15
+    on Tue Nov 21 18:12:16 2023
 If you publish work using this script the most relevant publication is:
 If you publish work using this script the most relevant publication is:
 
 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
     Peirce J, Gray JR, Simpson S, MacAskill M, Höchenberger R, Sogo H, Kastman E, Lindeløv JK. (2019) 
@@ -109,7 +109,8 @@ args = parse_args()
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
 receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
                                samplerate=config_info['sample_rate'],
                                samplerate=config_info['sample_rate'],
                                host=config_info['host'],
                                host=config_info['host'],
-                               port=config_info['port'])
+                               port=config_info['port']
+                               buffer_len=cofig_info['buffer_length'])
 
 
 # connect to trigger box
 # connect to trigger box
 trigger = TriggerNeuracle()
 trigger = TriggerNeuracle()
@@ -198,7 +199,7 @@ def setupData(expInfo, dataDir=None):
     thisExp = data.ExperimentHandler(
     thisExp = data.ExperimentHandler(
         name=expName, version='',
         name=expName, version='',
         extraInfo=expInfo, runtimeInfo=None,
         extraInfo=expInfo, runtimeInfo=None,
-        originPath='C:\\Users\\asena\\Desktop\\kraken\\backend\\general_grasp_training.py',
+        originPath='/Users/dingkunliu/Projects/MI-BCI-Proj/kraken/backend/general_grasp_training.py',
         savePickle=True, saveWideText=True,
         savePickle=True, saveWideText=True,
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
         dataFileName=dataDir + os.sep + filename, sortColumns='time'
     )
     )

+ 1 - 0
backend/settings/config.py

@@ -14,6 +14,7 @@ class Settings:
         'port': 8712,
         'port': 8712,
         'channel_count': 9,
         'channel_count': 9,
         'sample_rate': 1000,
         'sample_rate': 1000,
+        'buffer_length': 0.5,
         'channel_labels': [
         'channel_labels': [
             'CH001',
             'CH001',
             'CH002',
             'CH002',

+ 6 - 2
backend/training.py

@@ -13,9 +13,11 @@ import bci_core.feature_extractors as feature_extractors
 import bci_core.utils as bci_utils
 import bci_core.utils as bci_utils
 import bci_core.model as bci_model
 import bci_core.model as bci_model
 from dataloaders import neo
 from dataloaders import neo
+from settings.config import settings
 
 
 
 
 logging.basicConfig(level=logging.INFO)
 logging.basicConfig(level=logging.INFO)
+config_info = settings.CONFIG_INFO
 
 
 
 
 def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
 def train_model(raw, event_id, trial_duration=1., model_type='baseline'):
@@ -108,6 +110,7 @@ if __name__ == '__main__':
     subj_name = 'XW01'
     subj_name = 'XW01'
     model_type = 'riemann'
     model_type = 'riemann'
     # TODO: load subject config
     # TODO: load subject config
+    # include frequency band, model_type, upsampled_trial_duration
 
 
     data_dir = f'./data/{subj_name}/'
     data_dir = f'./data/{subj_name}/'
     model_dir = './static/models/'
     model_dir = './static/models/'
@@ -119,11 +122,12 @@ if __name__ == '__main__':
     for f in sessions.keys():
     for f in sessions.keys():
         event_id[f] = neo.FINGERMODEL_IDS[f]
         event_id[f] = neo.FINGERMODEL_IDS[f]
     
     
+    trial_duration = config_info['buffer_length']
     # preprocess raw
     # preprocess raw
-    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=1., ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
+    raw = neo.raw_preprocessing(data_dir, sessions, unify_label=True, upsampled_epoch_length=trial_duration, ori_epoch_length=5, mov_trial_ind=[2], rest_trial_ind=[1])
 
 
     # train model
     # train model
-    model = train_model(raw, event_id=event_id, model_type=model_type)
+    model = train_model(raw, event_id=event_id, model_type=model_type, trial_duration=trial_duration)
     
     
     # save
     # save
     model_saver(model, model_dir, model_type, subj_name, event_id)
     model_saver(model, model_dir, model_type, subj_name, event_id)

+ 13 - 7
backend/validation.py

@@ -13,20 +13,24 @@ from dataloaders import neo
 import bci_core.online as online
 import bci_core.online as online
 import bci_core.utils as bci_utils
 import bci_core.utils as bci_utils
 import bci_core.viz as bci_viz
 import bci_core.viz as bci_viz
+from settings.config import settings
 
 
 
 
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)
+config_info = settings.CONFIG_INFO
 
 
 
 
 class DataGenerator:
 class DataGenerator:
-    def __init__(self, fs, X):
+    def __init__(self, fs, X, epoch_step=1.):
         self.fs = int(fs)
         self.fs = int(fs)
         self.X = X
         self.X = X
+        self.epoch_step = epoch_step
 
 
     def get_data_batch(self, current_index):
     def get_data_batch(self, current_index):
-        # return 1s batch
+        # return epoch_step length batch
         # create mne object
         # create mne object
-        data = self.X[:, current_index - self.fs:current_index].copy()
+        ind = int(self.epoch_step * self.fs)
+        data = self.X[:, current_index - ind:current_index].copy()
         return self.fs, [], data
         return self.fs, [], data
 
 
     def loop(self, step_size=0.1):
     def loop(self, step_size=0.1):
@@ -35,13 +39,14 @@ class DataGenerator:
             yield i / self.fs, self.get_data_batch(i)
             yield i / self.fs, self.get_data_batch(i)
 
 
 
 
-def validation(raw_val, event_id, model, state_change_threshold=0.8):
+def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
     """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
     """模型验证接口,使用指定数据进行训练+验证,绘制ersd map
     Args:
     Args:
         raw (mne.io.Raw)
         raw (mne.io.Raw)
         event_id (dict)
         event_id (dict)
         model: validate existing model, 
         model: validate existing model, 
         state_change_threshold (float): default 0.8
         state_change_threshold (float): default 0.8
+        step_length (float): batch data step length, default 1. (s)
 
 
     Returns:
     Returns:
         None
         None
@@ -65,7 +70,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8):
 
 
     # validate with the second half
     # validate with the second half
     val_data = raw_val.get_data()
     val_data = raw_val.get_data()
-    data_gen = DataGenerator(fs, val_data)
+    data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
     rets = []
     rets = []
     for time, data in data_gen.loop():
     for time, data in data_gen.loop():
         cls = controller.decision(data)
         cls = controller.decision(data)
@@ -120,7 +125,8 @@ if __name__ == '__main__':
     metrics, fig_erds, fig_pred = validation(raw, 
     metrics, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              event_id, 
                                              model=model_path, 
                                              model=model_path, 
-                                             state_change_threshold=0.75)
+                                             state_change_threshold=0.75,
+                                             step_length=config_info['buffer_length'])
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')
     logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')

Some files were not shown because too many files changed in this diff