Browse Source

Feat: outlier filter

dk 1 year ago
parent
commit
686782ca79
4 changed files with 39 additions and 4 deletions
  1. 6 2
      backend/bci_core/pipeline.py
  2. 27 0
      backend/bci_core/utils.py
  3. 5 0
      backend/training.py
  4. 1 2
      backend/validation.py

+ 6 - 2
backend/bci_core/pipeline.py

@@ -2,7 +2,7 @@ import numpy as np
 
 from .model import riemann_feature_embedder, baseline_feature_embedder
 from .feature_extractors import FeatExtractor, FilterbankExtractor
-from .utils import cut_epochs
+from .utils import cut_epochs, events_filter
 
 
 def riemann_model_builder(fs, n_ch=8, lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
@@ -29,6 +29,10 @@ def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, retu
     filtered_data = feat_extractor.transform(raw)
     if (events is not None) and (duration is not None):
         X = cut_epochs((0, duration, fs), filtered_data, events[:, 0])
+        # filter out abnormal data
+        indices = events_filter(X, fs, events)
+        X = X[indices]
+        y_true = events[:, 2][indices]
     else:
         X = filtered_data[None]
     # embed feature
@@ -37,7 +41,7 @@ def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, retu
     prob = clf.predict_proba(X_embed)
     if return_cls:
         y_pred = clf.classes_[np.argmax(prob, axis=1)]
-        return prob, y_pred
+        return prob, y_pred, y_true
     else:
         return prob
 

+ 27 - 0
backend/bci_core/utils.py

@@ -1,4 +1,5 @@
 import numpy as np
+import mne
 import itertools
 from datetime import datetime
 import joblib
@@ -10,6 +11,32 @@ import os
 logger = logging.getLogger(__name__)
 
 
+def events_filter(epoch_data, fs, events):
+    """
+    Remove abnormal epochs according to their hg-band energy.
+
+    The epochs are band-pass filtered to 60-90 Hz, split into a rest group
+    (class label 0) and a movement group (every non-zero class label, pooled
+    together), and within each group any epoch whose RMS power lies outside
+    a band around the group mean is dropped.
+
+    Args:
+        epoch_data: epoch array; reduced over axes 1 and 2, so it is indexed
+            by epoch along axis 0 (presumably
+            (n_epochs, n_channels, n_samples) -- TODO confirm against
+            ``cut_epochs``).
+        fs: sampling frequency handed to ``mne.filter.filter_data``.
+        events: 2-D event array; column 2 is read as the class label and
+            must have one row per epoch in ``epoch_data``.
+
+    Returns:
+        Sorted 1-D array of the integer epoch indices that are kept.
+    """
+    def exclude_outlier(epoch_data, lower=1.5, upper=1.5):
+        # One RMS value per epoch (root of the mean square over axes 1 and 2).
+        epoch_power = np.sqrt(np.mean(epoch_data ** 2, axis=(1, 2)))
+        mean_, std_ = epoch_power.mean(), epoch_power.std()
+        # Boolean mask: True where the power lies strictly inside
+        # (mean - lower * std, mean + upper * std).
+        indices = (epoch_power < mean_ + upper * std_) & (epoch_power > mean_ - lower * std_)
+        return indices
+    
+    # extract hg (presumably "high gamma"): band-pass between 60 and 90 Hz
+    epoch_hg = mne.filter.filter_data(epoch_data, fs, 60, 90)
+
+    # Class label 0 is treated as rest; every other label counts as movement.
+    ind_cls = np.unique(events[:, 2])
+    mov_ind = [i for i in ind_cls if i != 0]
+    rest_ind = 0
+    # Keep the original epoch positions next to each group so the per-group
+    # masks can be mapped back to indices into the full epoch array.
+    epoch_rest = epoch_hg[events[:, 2] == rest_ind]
+    rest_indices = np.flatnonzero(events[:, 2] == rest_ind)
+    epoch_mov = epoch_hg[np.isin(events[:, 2], mov_ind)]
+    mov_indices = np.flatnonzero(np.isin(events[:, 2], mov_ind))
+    
+    # Rest uses a symmetric +/- 1.5 std band; movement uses an asymmetric
+    # band (1 std below the mean, 2 std above it).
+    rest_filtered_indices = rest_indices[exclude_outlier(epoch_rest, 1.5, 1.5)]
+    mov_filtered_indices = mov_indices[exclude_outlier(epoch_mov, 1., 2)]
+    # Merge both groups and restore ascending epoch order.
+    return np.sort(np.concatenate((rest_filtered_indices, mov_filtered_indices)))
+
+
 def event_to_stim_channel(events, time_length, trial_length=None):
     x = np.zeros(time_length, dtype=np.int32)
     if trial_length is not None:

+ 5 - 0
backend/training.py

@@ -64,6 +64,11 @@ def _param_search(model, raw, duration, events):
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
     y = events[:, -1]
 
+    # filter out abnormal data
+    indices = bci_utils.events_filter(X, fs, events)
+    X = X[indices]
+    y = y[indices]
+
     # embed feature
     X_embed = embedder.fit_transform(X)
 

+ 1 - 2
backend/validation.py

@@ -53,10 +53,9 @@ def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
     events, _ = mne.events_from_annotations(raw, event_id=event_id)
     # parse model type
     models = joblib.load(model_path)
-    prob, y_pred = bci_pipeline.data_evaluation(models, raw.get_data(), raw.info['sfreq'], events, trial_duration, True)
+    prob, y_pred, y = bci_pipeline.data_evaluation(models, raw.get_data(), raw.info['sfreq'], events, trial_duration, True)
     
     # metrics: AUC, accuracy, F1
-    y = events[:, -1]
     auc = bci_utils.multiclass_auc_score(y, prob)
     accu = accuracy_score(y, y_pred)
     f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')