1
0

viz.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import mne
  2. from mne.time_frequency import psd_array_welch
  3. import matplotlib.pyplot as plt
  4. import matplotlib as mpl
  5. import numpy as np
  6. from .utils import cut_epochs
  7. from .feature_extractors import filterbank_extractor
  8. def snapshot_brain(fig_3d, info, data=None, show_name=False):
  9. if data is not None:
  10. cmap = mpl.cm.viridis
  11. norm = mpl.colors.Normalize(vmin=data.min(), vmax=data.max())
  12. mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
  13. directions = [0, 180] # right, left
  14. figs = []
  15. for d in directions:
  16. # right
  17. mne.viz.set_3d_view(fig_3d, azimuth=d, elevation=70)
  18. xy, im = mne.viz.snapshot_brain_montage(fig_3d, info, hide_sensors=False)
  19. fig, ax = plt.subplots(figsize=(5, 5))
  20. ax.imshow(im, interpolation='none')
  21. ax.set_axis_off()
  22. if data is not None:
  23. fig.subplots_adjust(right=0.8)
  24. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  25. fig.colorbar(mappable, cax=cbar_ax)
  26. if show_name:
  27. xy_pts = np.vstack([xy[ch] for ch in info["ch_names"]])
  28. for i, pos in enumerate(xy_pts):
  29. ax.text(*pos, i, color='white')
  30. figs.append(fig)
  31. return figs
  32. def plot_time_frequency(data, events, sfreq, freqs, epoch_time_range, event_id):
  33. """
  34. data: numpy.ndarray, (n_ch, n_times)
  35. events: ndarray (n_events, 3), the first column is onset index, the second is duration, and the third is event type
  36. freqs: numpy.ndarray, frequency bands to filter
  37. epoch_time_range: tuple, (t_onset, t_offset)
  38. event_id: dict {id: name}
  39. """
  40. # extract power, (n_ch, n_freqs, n_times)
  41. power = filterbank_extractor(data, sfreq, freqs, reshape_freqs_dim=False)
  42. power = 10 * np.log10(power)
  43. # normalize by freqs
  44. power -= power.mean(axis=(0, 2), keepdims=True)
  45. power /= power.std(axis=(0, 2), keepdims=True)
  46. # image vlim
  47. mean_, std_ = power.mean(), power.std()
  48. # cut epochs
  49. epochs = cut_epochs((*epoch_time_range, sfreq), power, events[:, 0])
  50. # average by event type
  51. classes = np.unique(events[:, -1])
  52. fig, axes = plt.subplots(1, len(classes), figsize=(10, 5))
  53. for ax, y_ in zip(axes, classes):
  54. average_power = epochs[events[:, -1] == y_].mean(axis=(0, 1)) # keep freqencies and times
  55. im = ax.imshow(average_power, cmap='RdBu_r',
  56. vmin=mean_ - 0.5 * std_,
  57. vmax=mean_ + 0.5 * std_,
  58. aspect='auto',
  59. origin='lower')
  60. ax.set_xticks(np.linspace(-0.5, average_power.shape[1] - 0.5, 5))
  61. ax.set_xticklabels([f'{i:.2f}' for i in np.linspace(*epoch_time_range, 5)])
  62. ax.set_yticks(np.linspace(-0.5, average_power.shape[0] - 0.5, 10))
  63. ax.set_yticklabels([f'{int(i):3d}' for i in np.linspace(freqs[0], freqs[-1], 10)])
  64. ax.set_title(event_id[y_])
  65. fig.subplots_adjust(right=0.8)
  66. cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
  67. fig.colorbar(im, cax=cbar_ax)
  68. return fig
  69. def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
  70. event_desc = {v:k for k, v in event_id.items()}
  71. epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
  72. psd, freqs = psd_array_welch(epochs, sfreq)
  73. mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=(0, 1))
  74. ersds = []
  75. for e in np.unique(events[:, -1]):
  76. if e != rest_event:
  77. mean_psd = psd[events[:, -1] == e].mean(axis=(0, 1))
  78. ersd = mean_psd / mean_psd_rest - 1
  79. ersds.append((event_desc[e], ersd))
  80. fig, axes = plt.subplots(1, len(ersds), figsize=(5 * len(ersds), 5))
  81. if len(ersds) == 1:
  82. axes = [axes]
  83. for i, ersd in enumerate(ersds):
  84. axes[i].plot(freqs, ersd[1])
  85. axes[i].set_title(ersd[0])
  86. axes[i].set_ylabel('ERSD')
  87. axes[i].set_xlabel('Frequency [Hz]')
  88. return fig