pipeline.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import numpy as np
  2. from .model import riemann_feature_embedder, baseline_feature_embedder, cps_feature_embedder
  3. from .feature_extractors import FeatExtractor, FilterbankExtractor
  4. from .utils import cut_epochs
  5. def csp_model_builder(fs, n_components=8, lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
  6. feat_extractor = FeatExtractor(fs, lf_bands, hg_bands)
  7. embedder = cps_feature_embedder(n_components)
  8. return [feat_extractor, embedder]
  9. def riemann_model_builder(fs, n_ch=8, lf_bands=[(15, 35), (35, 50)], hg_bands=[(55, 95), (105, 145)]):
  10. feat_extractor = FeatExtractor(fs, lf_bands, hg_bands)
  11. # compute covariance
  12. feat_dim = []
  13. if lf_bands is not None:
  14. feat_dim.append(len(lf_bands) * n_ch)
  15. if hg_bands is not None:
  16. feat_dim.append(len(hg_bands) * n_ch)
  17. embedder = riemann_feature_embedder(feat_dim, estimator='lwf')
  18. return [feat_extractor, embedder]
  19. def baseline_model_builder(fs, freqs=(20, 150, 15), target_fs=10):
  20. filter_banks = np.arange(*freqs)
  21. feat_extractor = FilterbankExtractor(fs, filter_banks)
  22. embedder = baseline_feature_embedder(fs, target_fs, axis=-1)
  23. return [feat_extractor, embedder]
  24. def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, return_cls=True):
  25. feat_extractor, embedder, clf = model
  26. filtered_data = feat_extractor.transform(raw)
  27. if (events is not None) and (duration is not None):
  28. X = cut_epochs((0, duration, fs), filtered_data, events[:, 0])
  29. else:
  30. X = filtered_data[None]
  31. # embed feature
  32. X_embed = embedder.transform(X)
  33. # pred
  34. prob = clf.predict_proba(X_embed)
  35. if return_cls:
  36. y_pred = clf.classes_[np.argmax(prob, axis=1)]
  37. return prob, y_pred
  38. else:
  39. return prob