viz.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import mne
  2. from mne.time_frequency import psd_array_welch
  3. import matplotlib.pyplot as plt
  4. import matplotlib as mpl
  5. import numpy as np
  6. from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
  7. from .utils import cut_epochs
  8. from .feature_extractors import filterbank_extractor
  9. def snapshot_brain(fig_3d, info, data=None, show_name=False):
  10. if data is not None:
  11. cmap = mpl.cm.viridis
  12. norm = mpl.colors.Normalize(vmin=data.min(), vmax=data.max())
  13. mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
  14. directions = [0, 180] # right, left
  15. figs = []
  16. for d in directions:
  17. # right
  18. mne.viz.set_3d_view(fig_3d, azimuth=d, elevation=70)
  19. xy, im = mne.viz.snapshot_brain_montage(fig_3d, info, hide_sensors=False)
  20. fig, ax = plt.subplots(figsize=(5, 5))
  21. ax.imshow(im, interpolation='none')
  22. ax.set_axis_off()
  23. if data is not None:
  24. fig.subplots_adjust(right=0.8)
  25. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  26. fig.colorbar(mappable, cax=cbar_ax)
  27. if show_name:
  28. xy_pts = np.vstack([xy[ch] for ch in info["ch_names"]])
  29. for i, pos in enumerate(xy_pts):
  30. ax.text(*pos, i, color='white')
  31. figs.append(fig)
  32. return figs
  33. def plot_time_frequency(data, events, sfreq, freqs, epoch_time_range, event_id):
  34. """
  35. data: numpy.ndarray, (n_ch, n_times)
  36. events: ndarray (n_events, 3), the first column is onset index, the second is duration, and the third is event type
  37. freqs: numpy.ndarray, frequency bands to filter
  38. epoch_time_range: tuple, (t_onset, t_offset)
  39. event_id: dict {id: name}
  40. """
  41. # extract power, (n_ch, n_freqs, n_times)
  42. power = filterbank_extractor(data, sfreq, freqs, reshape_freqs_dim=False)
  43. power = 10 * np.log10(power)
  44. # normalize by freqs
  45. power -= power.mean(axis=(0, 2), keepdims=True)
  46. power /= power.std(axis=(0, 2), keepdims=True)
  47. # image vlim
  48. mean_, std_ = power.mean(), power.std()
  49. # cut epochs
  50. epochs = cut_epochs((*epoch_time_range, sfreq), power, events[:, 0])
  51. # average by event type
  52. classes = np.unique(events[:, -1])
  53. fig, axes = plt.subplots(1, len(classes), figsize=(10, 5))
  54. for ax, y_ in zip(axes, classes):
  55. average_power = epochs[events[:, -1] == y_].mean(axis=(0, 1)) # keep freqencies and times
  56. im = ax.imshow(average_power, cmap='RdBu_r',
  57. vmin=mean_ - 0.5 * std_,
  58. vmax=mean_ + 0.5 * std_,
  59. aspect='auto',
  60. origin='lower')
  61. ax.set_xticks(np.linspace(-0.5, average_power.shape[1] - 0.5, 5))
  62. ax.set_xticklabels([f'{i:.2f}' for i in np.linspace(*epoch_time_range, 5)])
  63. ax.set_yticks(np.linspace(-0.5, average_power.shape[0] - 0.5, 10))
  64. ax.set_yticklabels([f'{int(i):3d}' for i in np.linspace(freqs[0], freqs[-1], 10)])
  65. ax.set_title(event_id[y_])
  66. fig.subplots_adjust(right=0.8)
  67. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  68. fig.colorbar(im, cax=cbar_ax)
  69. return fig
  70. def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
  71. n_ch = data.shape[0]
  72. event_desc = {v:k for k, v in event_id.items()}
  73. epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
  74. psd, freqs = psd_array_welch(epochs, sfreq, fmin=0, fmax=200)
  75. mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=0)
  76. ersds = []
  77. for e in np.unique(events[:, -1]):
  78. if e != rest_event:
  79. mean_psd = psd[events[:, -1] == e].mean(axis=0)
  80. ersd = mean_psd / mean_psd_rest - 1
  81. ersds.append((event_desc[e], ersd))
  82. fig, axes = plt.subplots(n_ch, len(ersds), figsize=(3 * len(ersds), n_ch), sharex=True, sharey=True)
  83. for i in range(n_ch):
  84. if len(ersds) == 1:
  85. for j, ersd in enumerate(ersds):
  86. if i == 0:
  87. axes[i].set_title(ersd[0])
  88. axes[i].plot(freqs, ersd[1][i])
  89. axes[i].set_ylabel(f'ch_{i + 1}')
  90. axes[i].axhline(0, color='gray', linestyle='--')
  91. else:
  92. for j, ersd in enumerate(ersds):
  93. if i == 0:
  94. axes[i, j].set_title(ersd[0])
  95. axes[i, j].plot(freqs, ersd[1][i])
  96. axes[i, j].set_ylabel(f'ch_{i + 1}')
  97. axes[i, j].axhline(0, color='gray', linestyle='--')
  98. fig.suptitle('ERSD')
  99. return fig
  100. def plot_confusion_matrix(y_true, y_pred):
  101. cm = confusion_matrix(y_true, y_pred)
  102. disp = ConfusionMatrixDisplay(cm)
  103. disp.plot()
  104. return disp.figure_