Revert "Feat: outlier filter"

This reverts commit 686782ca799ed8e2a55fe6f1ca977f9d4f45862f.
dk · 1 year ago · parent · commit 9d5d14d1f6
4 changed files with 4 additions and 39 deletions
  1. backend/bci_core/pipeline.py  +2 -6
  2. backend/bci_core/utils.py  +0 -27
  3. backend/training.py  +0 -5
  4. backend/validation.py  +2 -1

+ 2 - 6
backend/bci_core/pipeline.py

@@ -2,7 +2,7 @@ import numpy as np
 
 from .model import riemann_feature_embedder, baseline_feature_embedder
 from .feature_extractors import FeatExtractor, FilterbankExtractor
-from .utils import cut_epochs, events_filter
+from .utils import cut_epochs
 
 
 def riemann_model_builder(fs, n_ch=8, lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
@@ -29,10 +29,6 @@ def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, retu
     filtered_data = feat_extractor.transform(raw)
     if (events is not None) and (duration is not None):
         X = cut_epochs((0, duration, fs), filtered_data, events[:, 0])
-        # filter out abnormal data
-        indices = events_filter(X, fs, events)
-        X = X[indices]
-        y_true = events[:, 2][indices]
     else:
         X = filtered_data[None]
     # embed feature
@@ -41,7 +37,7 @@ def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, retu
     prob = clf.predict_proba(X_embed)
     if return_cls:
         y_pred = clf.classes_[np.argmax(prob, axis=1)]
-        return prob, y_pred, y_true
+        return prob, y_pred
     else:
         return prob
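
After this revert, data_evaluation returns only (prob, y_pred) when return_cls is true, and callers recover the ground-truth labels from the events array themselves, as validation.py does below. A minimal usage sketch under that assumption (the model path, raw array, sampling rate and events are placeholders, and the import path is assumed):

    import joblib
    import numpy as np
    from bci_core import pipeline as bci_pipeline   # import path assumed

    models = joblib.load('model.pkl')                # placeholder model bundle
    # raw: (n_channels, n_samples) array, fs: sampling rate, events: (n_events, 3)
    prob, y_pred = bci_pipeline.data_evaluation(models, raw, fs, events, 1., True)
    y_true = events[:, -1]                           # labels now come from events
    accuracy = np.mean(y_pred == y_true)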
 

+ 0 - 27
backend/bci_core/utils.py

@@ -1,5 +1,4 @@
 import numpy as np
-import mne
 import itertools
 from datetime import datetime
 import joblib
@@ -11,32 +10,6 @@ import os
 logger = logging.getLogger(__name__)
 
 
-def events_filter(epoch_data, fs, events):
-    """
-    Remove abnormal data based on hg-band power
-    """
-    def exclude_outlier(epoch_data, lower=1.5, upper=1.5):
-        epoch_power = np.sqrt(np.mean(epoch_data ** 2, axis=(1, 2)))
-        mean_, std_ = epoch_power.mean(), epoch_power.std()
-        indices = (epoch_power < mean_ + upper * std_) & (epoch_power > mean_ - lower * std_)
-        return indices
-    
-    # extract hg
-    epoch_hg = mne.filter.filter_data(epoch_data, fs, 60, 90)
-
-    ind_cls = np.unique(events[:, 2])
-    mov_ind = [i for i in ind_cls if i != 0]
-    rest_ind = 0
-    epoch_rest = epoch_hg[events[:, 2] == rest_ind]
-    rest_indices = np.flatnonzero(events[:, 2] == rest_ind)
-    epoch_mov = epoch_hg[np.isin(events[:, 2], mov_ind)]
-    mov_indices = np.flatnonzero(np.isin(events[:, 2], mov_ind))
-    
-    rest_filtered_indices = rest_indices[exclude_outlier(epoch_rest, 1.5, 1.5)]
-    mov_filtered_indices = mov_indices[exclude_outlier(epoch_mov, 1., 2)]
-    return np.sort(np.concatenate((rest_filtered_indices, mov_filtered_indices)))
-
-
 def event_to_stim_channel(events, time_length, trial_length=None):
     x = np.zeros(time_length, dtype=np.int32)
     if trial_length is not None:
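
For context, the reverted events_filter kept an epoch only when its RMS power in the high-gamma band fell inside an asymmetric mean ± k·std window, computed separately for rest and movement trials. A self-contained sketch of just that power criterion on synthetic epochs (array shapes and thresholds follow the removed code; the data here is synthetic and the band-pass step is skipped):

    import numpy as np

    def exclude_outlier(epoch_data, lower=1.5, upper=1.5):
        # epoch_data: (n_epochs, n_channels, n_samples); RMS power per epoch
        epoch_power = np.sqrt(np.mean(epoch_data ** 2, axis=(1, 2)))
        mean_, std_ = epoch_power.mean(), epoch_power.std()
        # keep epochs whose power lies inside mean - lower*std .. mean + upper*std
        return (epoch_power < mean_ + upper * std_) & (epoch_power > mean_ - lower * std_)

    rng = np.random.default_rng(0)
    epochs = rng.normal(size=(40, 8, 500))    # stand-in for hg-filtered epochs
    epochs[3] *= 10                           # inject one high-power outlier
    keep = exclude_outlier(epochs)
    print(np.flatnonzero(~keep))              # the injected outlier is flagged for removal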

+ 0 - 5
backend/training.py

@@ -64,11 +64,6 @@ def _param_search(model, raw, duration, events):
     X = bci_utils.cut_epochs((0, duration, fs), filtered_data, events[:, 0])
     y = events[:, -1]
 
-    # filter out abnormal data
-    indices = bci_utils.events_filter(X, fs, events)
-    X = X[indices]
-    y = y[indices]
-
     # embed feature
     X_embed = embedder.fit_transform(X)
 

+ 2 - 1
backend/validation.py

@@ -53,9 +53,10 @@ def val_by_epochs(raw, model_path, event_id, trial_duration=1., ):
     events, _ = mne.events_from_annotations(raw, event_id=event_id)
     # parse model type
     models = joblib.load(model_path)
-    prob, y_pred, y = bci_pipeline.data_evaluation(models, raw.get_data(), raw.info['sfreq'], events, trial_duration, True)
+    prob, y_pred = bci_pipeline.data_evaluation(models, raw.get_data(), raw.info['sfreq'], events, trial_duration, True)
     
     # metrics: AUC, accuracy, F1
+    y = events[:, -1]
     auc = bci_utils.multiclass_auc_score(y, prob)
     accu = accuracy_score(y, y_pred)
     f1 = f1_score(y, y_pred, pos_label=np.max(y), average='macro')