Ver Fonte

Feat: plot hg line

dk há 9 meses atrás
pai
commit
07a414a996
3 ficheiros alterados com 47 adições e 1 exclusões
  1. 17 0
      backend/bci_core/utils.py
  2. 24 1
      backend/bci_core/viz.py
  3. 6 0
      backend/signal_visualization.py

+ 17 - 0
backend/bci_core/utils.py

@@ -4,6 +4,7 @@ from datetime import datetime
 import joblib
 from sklearn.model_selection import KFold
 from sklearn.metrics import roc_auc_score
+from mne import baseline
 import logging
 import os
 
@@ -130,6 +131,22 @@ def cut_epochs(t, data, timestamps):
     return epochs
 
 
+def apply_baseline(t, data, mode='mean'):
+    """
+    Simple wrapper of mne rescale function
+    :param t: tuple (start, end, samplerate)
+    :param data: ndarray of any shape with axis=-1 the time axis
+    :param mode: 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
+        refer to mne.baseline.rescale
+    :return: ndarray
+    """
+    start, end, samplerate = t
+    base = (start, 0)
+    times = np.linspace(start, end, data.shape[-1])
+    data = baseline.rescale(data, times, baseline=base, mode=mode, verbose=False)
+    return data
+
+
 def product_dict(**kwargs):
     keys = kwargs.keys()
     vals = kwargs.values()

+ 24 - 1
backend/bci_core/viz.py

@@ -5,8 +5,9 @@ import matplotlib as mpl
 import numpy as np
 from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
 from scipy.ndimage import gaussian_filter
+from scipy import signal
 
-from .utils import cut_epochs
+from .utils import cut_epochs, apply_baseline
 from .feature_extractors import filterbank_extractor
 
 
@@ -30,6 +31,28 @@ def plot_embeddings(embd, y, event_id, size=20, figsize=(4, 5), show_legend=True
     return fig
 
 
+def plot_hg_envelope(raw, events, event_id, fs, freqs, tmin, tmax, t_smooth=0.3, target_event='flex'):
+    power = raw.filter(*freqs).apply_hilbert(envelope=True).get_data()
+    # moving average
+    n_smooth = int(t_smooth * fs)
+    # smooth
+    power = signal.filtfilt(np.ones(n_smooth) / n_smooth, 1, power, axis=-1)
+    epochs_hg = cut_epochs((tmin, tmax, fs), power, events[:, 0])
+    epochs_hg = apply_baseline((tmin, tmax, fs), epochs_hg)
+    
+    times = np.linspace(tmin, tmax, epochs_hg.shape[-1])
+    move_average = epochs_hg[events[:, 2] == event_id[target_event]].mean(axis=0)
+    move_se = epochs_hg[events[:, 2] == event_id[target_event]].std(axis=0) / np.sqrt(np.sum(events[:, 2] == event_id[target_event]))
+    n_ch = power.shape[0]
+    fig, axes = plt.subplots(1, 1)
+    for i in range(n_ch):
+        axes.plot(times, move_average[i], label=f'ch_{i + 1}')
+        axes.fill_between(times, move_average[i] - move_se[i], move_average[i] + move_se[i], alpha=0.1)
+    axes.legend()
+    fig.suptitle(f'HG line plot ({freqs[0], freqs[1]}))')
+    return fig
+
+
 def snapshot_brain(fig_3d, info, data=None, show_name=False):
     if data is not None:
         cmap = mpl.cm.viridis

+ 6 - 0
backend/signal_visualization.py

@@ -66,6 +66,12 @@ if __name__ == '__main__':
     # plot raw tfr
     fig_tfr_raw = bci_viz.plot_raw_tfr(raw.get_data(), fs, np.arange(5, 200, 10), n_cycles=20)
 
+    # hg average line plot
+    fig_hgs = {}
+    for t in sessions.keys():
+        fig_hg_1 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (55, 95), -1, 5, target_event=t)
+        fig_hg_2 = bci_viz.plot_hg_envelope(raw, events, event_id, fs, (95, 155), -1, 5, t_smooth=0.6, target_event=t)
+        fig_hgs[t] = (fig_hg_1, fig_hg_2)
     
     fig_erds.savefig(os.path.join(data_dir, 'erds.pdf'))
     fig_tfr.savefig(os.path.join(data_dir, 'tfr.pdf'))