12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- import numpy as np
- from .model import riemann_feature_embedder, baseline_feature_embedder, cps_feature_embedder
- from .feature_extractors import FeatExtractor, FilterbankExtractor
- from .utils import cut_epochs
def csp_model_builder(fs, n_components=8, lf_bands=((15, 35), (35, 50)), hg_bands=((55, 95), (105, 145))):
    """Build the feature-extraction and embedding stages of the CSP pipeline.

    Parameters
    ----------
    fs :
        Forwarded to ``FeatExtractor`` (presumably the sampling rate --
        TODO confirm against ``FeatExtractor``).
    n_components : int, default 8
        Forwarded to ``cps_feature_embedder``.
    lf_bands, hg_bands : sequence of 2-element bands, or None
        Band definitions forwarded to ``FeatExtractor``. ``None`` is
        forwarded unchanged.

    Returns
    -------
    list
        ``[feat_extractor, embedder]``, in that order.
    """
    # The signature defaults are immutable tuples so no state is shared
    # between calls; the extractor still receives a plain list (as it did
    # with the former list defaults), but a fresh one on every call.
    lf_bands = None if lf_bands is None else list(lf_bands)
    hg_bands = None if hg_bands is None else list(hg_bands)

    feat_extractor = FeatExtractor(fs, lf_bands, hg_bands)
    # NOTE: the embedder factory is imported under the name
    # ``cps_feature_embedder`` (sic) from ``.model``.
    embedder = cps_feature_embedder(n_components)
    return [feat_extractor, embedder]
def riemann_model_builder(fs, n_ch=8, lf_bands=((15, 35), (35, 50)), hg_bands=((55, 95), (105, 145))):
    """Build the feature-extraction and embedding stages of the Riemannian pipeline.

    Parameters
    ----------
    fs :
        Forwarded to ``FeatExtractor`` (presumably the sampling rate --
        TODO confirm against ``FeatExtractor``).
    n_ch : int, default 8
        Multiplier applied to the number of bands in each group when
        computing the embedder's feature dimensions (presumably the channel
        count -- TODO confirm).
    lf_bands, hg_bands : sequence of 2-element bands, or None
        Band definitions forwarded to ``FeatExtractor``. A group set to
        ``None`` is forwarded unchanged and contributes no entry to the
        feature-dimension list.

    Returns
    -------
    list
        ``[feat_extractor, embedder]``, in that order.
    """
    # The signature defaults are immutable tuples so no state is shared
    # between calls; the extractor still receives a plain list (as it did
    # with the former list defaults), but a fresh one on every call.
    lf_bands = None if lf_bands is None else list(lf_bands)
    hg_bands = None if hg_bands is None else list(hg_bands)

    feat_extractor = FeatExtractor(fs, lf_bands, hg_bands)

    # One entry per band group that is present, low-frequency group first:
    # (number of bands in the group) * n_ch.
    feat_dim = [
        len(bands) * n_ch
        for bands in (lf_bands, hg_bands)
        if bands is not None
    ]
    embedder = riemann_feature_embedder(feat_dim, estimator='lwf')
    return [feat_extractor, embedder]
def baseline_model_builder(fs, freqs=(20, 150, 15), target_fs=10):
    """Build the filter-bank extractor and baseline embedder pair.

    Parameters
    ----------
    fs :
        Forwarded to both ``FilterbankExtractor`` and
        ``baseline_feature_embedder`` (presumably the sampling rate --
        TODO confirm).
    freqs : tuple, default (20, 150, 15)
        Positional arguments for ``np.arange``; the resulting array is
        handed to ``FilterbankExtractor`` as its filter-bank definition.
    target_fs : default 10
        Forwarded to ``baseline_feature_embedder``.

    Returns
    -------
    list
        ``[feat_extractor, embedder]``, in that order.
    """
    band_grid = np.arange(*freqs)
    return [
        FilterbankExtractor(fs, band_grid),
        baseline_feature_embedder(fs, target_fs, axis=-1),
    ]
def data_evaluation(model, raw: np.ndarray, fs, events=None, duration=None, return_cls=True):
    """Run a fitted three-stage model on raw data and return its predictions.

    Parameters
    ----------
    model :
        Iterable unpacking to ``(feat_extractor, embedder, clf)``. The first
        two must expose ``transform``; ``clf`` must expose ``predict_proba``
        and ``classes_``.
    raw : np.ndarray
        Input passed to ``feat_extractor.transform``.
    fs :
        Only used for epoching; passed to ``cut_epochs`` as the third
        element of its window tuple.
    events : optional
        When given together with ``duration``, its first column
        (``events[:, 0]``) is passed to ``cut_epochs`` (presumably event
        onsets -- TODO confirm).
    duration : optional
        End of the ``(0, duration, fs)`` window passed to ``cut_epochs``.
    return_cls : bool, default True
        Whether to also return the predicted class labels.

    Returns
    -------
    prob
        Output of ``clf.predict_proba``.
    y_pred
        Only when ``return_cls`` is truthy: ``clf.classes_`` indexed by the
        arg-max of ``prob`` along axis 1. Returned as ``(prob, y_pred)``.
    """
    extractor, feature_embedder, classifier = model
    features = extractor.transform(raw)

    if events is None or duration is None:
        # No complete epoch spec: treat the whole input as one sample by
        # adding a leading axis of length 1.
        samples = features[None]
    else:
        samples = cut_epochs((0, duration, fs), features, events[:, 0])

    probabilities = classifier.predict_proba(feature_embedder.transform(samples))
    if not return_cls:
        return probabilities

    best = np.argmax(probabilities, axis=1)
    return probabilities, classifier.classes_[best]
|