Browse Source

Feat: add common average for neo data

dk 1 year ago
parent
commit
04e0f218b0
4 changed files with 18 additions and 14 deletions
  1. 7 1
      backend/bci_core/online.py
  2. 7 8
      backend/dataloaders/neo.py
  3. 1 1
      backend/training.py
  4. 3 4
      backend/validation.py

+ 7 - 1
backend/bci_core/online.py

@@ -3,6 +3,7 @@ import numpy as np
 import random
 import logging
 from scipy import signal
+import mne
 from .feature_extractors import filterbank_extractor
 from .utils import parse_model_type
 
@@ -132,7 +133,10 @@ class HMMModel:
         self._probability[current_state] = 1
     
     def step_probability(self, fs, data):
-        raise NotImplementedError
+        # do preprocessing here
+        # common average
+        data -= data.mean(axis=0)
+        return data
     
     def parse_data(self, data):
         fs, event, data_array = data
@@ -181,6 +185,7 @@ class BaselineHMM(HMMModel):
     def step_probability(self, fs, data):
         """Step 
         """
+        data = super(BaselineHMM, self).step_probability(fs, data)
         # filter data
         filter_bank_data = filterbank_extractor(data, fs, self.freqs, reshape_freqs_dim=True)
         # downsampling
@@ -203,6 +208,7 @@ class RiemannHMM(HMMModel):
     def step_probability(self, fs, data):
         """Step 
         """
+        data = super(RiemannHMM, self).step_probability(fs, data)
         data = self.feat_extractor.transform(data)
         # scale data
         data = self.scaler.transform(data)

+ 7 - 8
backend/dataloaders/neo.py

@@ -4,17 +4,14 @@ import json
 import mne
 import glob
 import pyedflib
+from scipy import signal
 from .utils import upsample_events
+from settings.config import settings
 
 
-FINGERMODEL_IDS = {
-    'rest': 0,
-    'cylinder': 1,
-    'ball': 2,
-    'flex': 3,
-    'double': 4,
-    'treble': 5
-}
+FINGERMODEL_IDS = settings.FINGERMODEL_IDS
+
+CONFIG_INFO = settings.CONFIG_INFO
 
 
 def raw_preprocessing(data_root, session_paths:dict, 
@@ -64,6 +61,8 @@ def raw_preprocessing(data_root, session_paths:dict,
     raws = mne.concatenate_raws(raws)
     raws.load_data()
 
+    # common average
+    raws.set_eeg_reference('average')
     # high pass
     raws = raws.filter(1, None)
     # filter 50Hz

+ 1 - 1
backend/training.py

@@ -108,7 +108,7 @@ def model_saver(model, model_path, model_type, subject_id, event_id):
 if __name__ == '__main__':
     # TODO: argparse
     subj_name = 'ylj'
-    model_type = 'riemann'
+    model_type = 'baseline'
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/train/'

+ 3 - 4
backend/validation.py

@@ -15,7 +15,7 @@ import bci_core.utils as bci_utils
 import bci_core.viz as bci_viz
 
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 
 
 class DataGenerator:
@@ -102,11 +102,10 @@ def _event_to_stim_channel(events, time_length):
 if __name__ == '__main__':
     # TODO: argparse
     subj_name = 'ylj'
-    model_type = 'baseline'
     # TODO: load subject config
 
     data_dir = f'./data/{subj_name}/train/'
-    model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-15-2023-21-34-41.pkl'
+    model_path = f'./static/models/{subj_name}/baseline_rest+cylinder_11-16-2023-16-38-32.pkl'
 
     with open(os.path.join(data_dir, 'info.yml'), 'r') as f:
         info = yaml.safe_load(f)
@@ -122,7 +121,7 @@ if __name__ == '__main__':
     metrics, fig_erds, fig_pred = validation(raw, 
                                              event_id, 
                                              model=model_path, 
-                                             state_change_threshold=0.95)
+                                             state_change_threshold=0.8)
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_pred.savefig(os.path.join(data_dir, 'pred.pdf'))   
     logging.info(f'precision: {metrics[0]:.4f}, recall: {metrics[1]:.4f}, f_beta_score: {metrics[2]:.4f}, corr: {metrics[3]:.4f}')