viz.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import mne
  2. from mne.time_frequency import psd_array_multitaper
  3. import matplotlib.pyplot as plt
  4. import matplotlib as mpl
  5. import numpy as np
  6. from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
  7. from .utils import cut_epochs
  8. from .feature_extractors import filterbank_extractor
  9. def plot_embeddings(embd, y, event_id, size=20, figsize=(4, 5), show_legend=True):
  10. fig, ax = plt.subplots(figsize=figsize)
  11. for label in event_id.keys():
  12. l = event_id[label]
  13. idx = y == l
  14. ax.scatter(embd[idx, 0], embd[idx, 1], s=size, label=label)
  15. ax.set_xlabel(r"PC1")
  16. ax.set_ylabel(r"PC2")
  17. if show_legend:
  18. ax.legend(frameon=False)
  19. ax.spines['top'].set_visible(False)
  20. ax.spines['right'].set_visible(False)
  21. return fig
  22. def snapshot_brain(fig_3d, info, data=None, show_name=False):
  23. if data is not None:
  24. cmap = mpl.cm.viridis
  25. norm = mpl.colors.Normalize(vmin=data.min(), vmax=data.max())
  26. mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
  27. directions = [0, 180] # right, left
  28. figs = []
  29. for d in directions:
  30. # right
  31. mne.viz.set_3d_view(fig_3d, azimuth=d, elevation=70)
  32. xy, im = mne.viz.snapshot_brain_montage(fig_3d, info, hide_sensors=False)
  33. fig, ax = plt.subplots(figsize=(5, 5))
  34. ax.imshow(im, interpolation='none')
  35. ax.set_axis_off()
  36. if data is not None:
  37. fig.subplots_adjust(right=0.8)
  38. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  39. fig.colorbar(mappable, cax=cbar_ax)
  40. if show_name:
  41. xy_pts = np.vstack([xy[ch] for ch in info["ch_names"]])
  42. for i, pos in enumerate(xy_pts):
  43. ax.text(*pos, i, color='white')
  44. figs.append(fig)
  45. return figs
  46. def plot_time_frequency(data, events, sfreq, freqs, epoch_time_range, event_desc):
  47. """
  48. data: numpy.ndarray, (n_ch, n_times)
  49. events: ndarray (n_events, 3), the first column is onset index, the second is duration, and the third is event type
  50. freqs: numpy.ndarray, frequency bands to filter
  51. epoch_time_range: tuple, (t_onset, t_offset)
  52. event_desc: dict {id: name}
  53. """
  54. # extract power, (n_ch, n_freqs, n_times)
  55. power = filterbank_extractor(data, sfreq, freqs, reshape_freqs_dim=False)
  56. power = 10 * np.log10(power)
  57. # normalize by freqs
  58. power -= power.mean(axis=(0, 2), keepdims=True)
  59. power /= power.std(axis=(0, 2), keepdims=True)
  60. # image vlim
  61. mean_, std_ = power.mean(), power.std()
  62. # cut epochs
  63. epochs = cut_epochs((*epoch_time_range, sfreq), power, events[:, 0])
  64. # average by event type
  65. classes = np.unique(events[:, -1])
  66. fig, axes = plt.subplots(1, len(classes), figsize=(10, 5))
  67. for ax, y_ in zip(axes, classes):
  68. average_power = epochs[events[:, -1] == y_].mean(axis=(0, 1)) # keep freqencies and times
  69. im = ax.imshow(average_power, cmap='RdBu_r',
  70. extent=[*epoch_time_range, freqs[0], freqs[-1]],
  71. vmin=mean_ - 0.5 * std_,
  72. vmax=mean_ + 0.5 * std_,
  73. aspect='auto',
  74. origin='lower')
  75. ax.axvline(0, color='k')
  76. ax.set_title(event_desc[y_])
  77. fig.subplots_adjust(right=0.8)
  78. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  79. fig.colorbar(im, cax=cbar_ax)
  80. return fig
  81. def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
  82. n_ch = data.shape[0]
  83. event_desc = {v:k for k, v in event_id.items()}
  84. epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
  85. psd, freqs = psd_array_multitaper(epochs, sfreq, fmin=0, fmax=200, bandwidth=15)
  86. mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=0)
  87. ersds = []
  88. for e in np.unique(events[:, -1]):
  89. if e != rest_event:
  90. mean_psd = psd[events[:, -1] == e].mean(axis=0)
  91. ersd = mean_psd / mean_psd_rest - 1
  92. ersds.append((event_desc[e], ersd))
  93. fig, axes = plt.subplots(n_ch, len(ersds), figsize=(3 * len(ersds), n_ch), sharex=True, sharey=True)
  94. for i in range(n_ch):
  95. if len(ersds) == 1:
  96. for j, ersd in enumerate(ersds):
  97. if i == 0:
  98. axes[i].set_title(ersd[0])
  99. axes[i].plot(freqs, ersd[1][i])
  100. axes[i].set_ylabel(f'ch_{i + 1}')
  101. axes[i].axhline(0, color='gray', linestyle='--')
  102. else:
  103. for j, ersd in enumerate(ersds):
  104. if i == 0:
  105. axes[i, j].set_title(ersd[0])
  106. axes[i, j].plot(freqs, ersd[1][i])
  107. axes[i, j].set_ylabel(f'ch_{i + 1}')
  108. axes[i, j].axhline(0, color='gray', linestyle='--')
  109. fig.suptitle('ERSD')
  110. return fig
  111. def plot_confusion_matrix(y_true, y_pred):
  112. cm = confusion_matrix(y_true, y_pred, normalize='true')
  113. disp = ConfusionMatrixDisplay(cm)
  114. disp.plot()
  115. return disp.figure_
  116. def plot_states(time_range, pred_states, ax, colors=None):
  117. classes = np.unique(pred_states)
  118. if colors is None:
  119. colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
  120. for i, c in enumerate(classes):
  121. ax.fill_between(np.linspace(*time_range, len(pred_states)), 0, 1,
  122. where=(pred_states == c), alpha=0.6, color=colors[i])
  123. return ax
  124. def plot_state_prob_with_cue(time_range, true_states, pred_probs, ax, colors=None):
  125. # normalize
  126. ax.plot(np.linspace(*time_range, len(pred_probs)), pred_probs, color='k')
  127. # for each class, fill different colors
  128. classes = np.unique(true_states)
  129. if colors is None:
  130. colors = [plt.get_cmap('tab10')(i)[:3] for i in range(len(classes))]
  131. for i, c in enumerate(classes):
  132. ax.fill_between(np.linspace(*time_range, len(true_states)), 0, 1,
  133. where=(true_states == c), alpha=0.6, color=colors[i])
  134. return ax