pipeline.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. import numpy as np
  2. from .model import riemann_feature_embedder, baseline_feature_embedder
  3. from .feature_extractors import FeatExtractor, FilterbankExtractor
  4. from .utils import cut_epochs
  5. def riemann_model_builder(fs, n_ch=8, lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
  6. feat_extractor = FeatExtractor(fs, lf_bands, hg_bands)
  7. # compute covariance
  8. feat_dim = []
  9. if lf_bands is not None:
  10. feat_dim.append(len(lf_bands) * n_ch)
  11. if hg_bands is not None:
  12. feat_dim.append(len(hg_bands) * n_ch)
  13. embedder = riemann_feature_embedder(feat_dim, estimator='lwf')
  14. return [feat_extractor, embedder]
  15. def baseline_model_builder(fs, freqs=(20, 150, 15), target_fs=10):
  16. filter_banks = np.arange(*freqs)
  17. feat_extractor = FilterbankExtractor(fs, filter_banks)
  18. embedder = baseline_feature_embedder(fs, target_fs, axis=-1)
  19. return [feat_extractor, embedder]
  20. def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, return_cls=True):
  21. feat_extractor, embedder, clf = model
  22. filtered_data = feat_extractor.transform(raw)
  23. if (events is not None) and (duration is not None):
  24. X = cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  25. else:
  26. X = filtered_data[None]
  27. # embed feature
  28. X_embed = embedder.transform(X)
  29. # pred
  30. prob = clf.predict_proba(X_embed)
  31. if return_cls:
  32. y_pred = clf.classes_[np.argmax(prob, axis=1)]
  33. return prob, y_pred
  34. else:
  35. return prob