12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import mne
- from mne.time_frequency import psd_array_welch
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- import numpy as np
- from .utils import cut_epochs
- from .feature_extractors import filterbank_extractor
def snapshot_brain(fig_3d, info, data=None, show_name=False):
    """Capture 2-D snapshots of a 3-D brain figure from two viewpoints.

    The 3-D view is set to azimuth 0 and then azimuth 180 (elevation 70 in
    both cases). Each view is rendered with
    ``mne.viz.snapshot_brain_montage`` (sensors visible) into its own
    5x5-inch matplotlib figure with the axes hidden.

    Parameters
    ----------
    fig_3d :
        The 3-D figure to photograph (whatever ``mne.viz.set_3d_view`` and
        ``mne.viz.snapshot_brain_montage`` accept).
    info :
        Measurement info; ``info["ch_names"]`` gives the sensor order used
        for labelling.
    data : numpy.ndarray, optional
        When given, a viridis colorbar spanning ``data.min()`` to
        ``data.max()`` is attached to every snapshot. ``data`` only sets the
        colorbar range; it is not drawn onto the sensors by this function.
    show_name : bool, default False
        When True, each sensor is annotated in white with its position
        (integer index) in ``info["ch_names"]`` -- not with its name.

    Returns
    -------
    list of matplotlib.figure.Figure
        One figure per viewpoint, ordered azimuth 0 then azimuth 180.
    """
    mappable = None
    if data is not None:
        colour_range = mpl.colors.Normalize(vmin=data.min(), vmax=data.max())
        mappable = mpl.cm.ScalarMappable(norm=colour_range, cmap=mpl.cm.viridis)

    snapshots = []
    # Author's labelling: azimuth 0 is the right view, 180 the left view.
    for azimuth in (0, 180):
        mne.viz.set_3d_view(fig_3d, azimuth=azimuth, elevation=70)
        sensor_xy, image = mne.viz.snapshot_brain_montage(
            fig_3d, info, hide_sensors=False
        )

        snap_fig, snap_ax = plt.subplots(figsize=(5, 5))
        snap_ax.imshow(image, interpolation='none')
        snap_ax.set_axis_off()

        if mappable is not None:
            # Shrink the image area to make room for a colorbar on the right.
            snap_fig.subplots_adjust(right=0.8)
            colorbar_ax = snap_fig.add_axes([0.85, 0.15, 0.05, 0.7])
            snap_fig.colorbar(mappable, cax=colorbar_ax)

        if show_name:
            # Stack the 2-D sensor positions in info["ch_names"] order so the
            # label index matches the channel index.
            points = np.vstack([sensor_xy[name] for name in info["ch_names"]])
            for index, point in enumerate(points):
                snap_ax.text(*point, index, color='white')

        snapshots.append(snap_fig)
    return snapshots
def plot_time_frequency(data, events, sfreq, freqs, epoch_time_range, event_id):
    """Plot class-averaged, frequency-normalized time-frequency power maps.

    Band power is extracted with ``filterbank_extractor`` and converted to
    dB. It is then z-scored per frequency (over channels and time) before
    being cut into epochs. One image is drawn per event type (the average
    over that type's epochs and all channels). All images share the colour
    limits ``mean +/- 0.5 * std`` of the normalized power and a single
    colorbar.

    Parameters
    ----------
    data : numpy.ndarray, shape (n_ch, n_times)
        Continuous signal.
    events : numpy.ndarray, shape (n_events, 3)
        The first column is the onset index, the second is the duration and
        the third (last) is the event type.
    sfreq : float
        Sampling frequency, passed to ``filterbank_extractor`` and
        ``cut_epochs``.
    freqs : numpy.ndarray
        Frequency bands to filter. The y tick labels are interpolated
        linearly between ``freqs[0]`` and ``freqs[-1]``, so they are only
        exact for evenly spaced ``freqs``.
    epoch_time_range : tuple
        ``(t_onset, t_offset)`` passed to ``cut_epochs``; also used for the
        x tick labels.
    event_id : dict
        ``{event_type: name}``; the name is used as each subplot title.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with one subplot per event type present in ``events``.

    Raises
    ------
    ValueError
        If ``events`` is empty (there is nothing to plot).
    """
    if len(events) == 0:
        raise ValueError("events is empty; there is nothing to plot")

    # Extract power; shape presumably (n_ch, n_freqs, n_times) -- the
    # normalization axes below rely on that, confirm with filterbank_extractor.
    power = filterbank_extractor(data, sfreq, freqs, reshape_freqs_dim=False)
    power = 10 * np.log10(power)
    # z-score each frequency over channels and time so bands with very
    # different absolute power can share one colour scale.
    power -= power.mean(axis=(0, 2), keepdims=True)
    power /= power.std(axis=(0, 2), keepdims=True)
    # Global statistics define shared colour limits for every subplot.
    mean_, std_ = power.mean(), power.std()

    # NOTE(review): assumes cut_epochs returns one epoch per event, in event
    # order, so the label mask below lines up with the epochs axis.
    epochs = cut_epochs((*epoch_time_range, sfreq), power, events[:, 0])
    labels = events[:, -1]
    classes = np.unique(labels)

    # squeeze=False keeps `axes` 2-D even for a single class; without it
    # plt.subplots returns a bare Axes, which is not iterable.
    fig, axes = plt.subplots(1, len(classes), figsize=(10, 5), squeeze=False)
    for ax, y_ in zip(axes[0], classes):
        # Average over epochs and channels, keeping frequencies and times.
        average_power = epochs[labels == y_].mean(axis=(0, 1))
        im = ax.imshow(average_power, cmap='RdBu_r',
                       vmin=mean_ - 0.5 * std_,
                       vmax=mean_ + 0.5 * std_,
                       aspect='auto',
                       origin='lower')
        # Ticks run from the outer pixel edges (-0.5 .. n - 0.5).
        ax.set_xticks(np.linspace(-0.5, average_power.shape[1] - 0.5, 5))
        ax.set_xticklabels([f'{i:.2f}' for i in np.linspace(*epoch_time_range, 5)])
        ax.set_yticks(np.linspace(-0.5, average_power.shape[0] - 0.5, 10))
        ax.set_yticklabels([f'{int(i):3d}' for i in np.linspace(freqs[0], freqs[-1], 10)])
        ax.set_title(event_id[y_])

    # One colorbar for all subplots; valid because every image shares vmin/vmax.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    return fig
def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
    """Plot the spectral change of each event type relative to rest.

    Epochs are cut from ``data`` and their Welch PSD is computed. The PSD is
    averaged over epochs and channels for each event type, and the value
    plotted for every non-rest event type is
    ``mean_psd_event / mean_psd_rest - 1``. A value of 0 means "same power
    as rest", and 0.5 means "50 % more power than rest".

    Parameters
    ----------
    data : numpy.ndarray
        Continuous signal passed to ``cut_epochs``.
    events : numpy.ndarray, shape (n_events, 3)
        Onset index in the first column and event type in the last column.
    sfreq : float
        Sampling frequency, passed to ``cut_epochs`` and ``psd_array_welch``.
    epoch_time_range : tuple
        ``(t_onset, t_offset)`` passed to ``cut_epochs``.
    event_id : dict
        ``{name: event_type}``. It is inverted internally to look up subplot
        titles; event types missing from it are titled with their code.
    rest_event : int, default 0
        Event type used as the baseline.

    Returns
    -------
    matplotlib.figure.Figure
        One subplot per non-rest event type, with the y-axis labelled 'ERSD'.

    Raises
    ------
    ValueError
        If there are no ``rest_event`` epochs (the baseline would be NaN), or
        no non-rest event types to plot.
    """
    event_desc = {v: k for k, v in event_id.items()}
    labels = events[:, -1]

    rest_mask = labels == rest_event
    if not rest_mask.any():
        raise ValueError(
            f"no epochs with rest_event={rest_event!r}; cannot compute a baseline"
        )
    task_events = [e for e in np.unique(labels) if e != rest_event]
    if not task_events:
        raise ValueError("no non-rest event types to compare against rest")

    # NOTE(review): assumes cut_epochs returns one epoch per event, in event
    # order, so the label masks line up with the epochs axis.
    epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
    # psd_array_welch defaults to n_fft=256 and raises if that exceeds the
    # epoch length. Clamping keeps results identical for long epochs; short
    # epochs use a single full-length segment instead of failing.
    n_fft = min(256, epochs.shape[-1])
    psd, freqs = psd_array_welch(epochs, sfreq, n_fft=n_fft)

    # Average over epochs and channels, keeping the frequency axis.
    mean_psd_rest = psd[rest_mask].mean(axis=(0, 1))
    ersds = []
    for e in task_events:
        mean_psd = psd[labels == e].mean(axis=(0, 1))
        ersds.append((event_desc.get(e, str(e)), mean_psd / mean_psd_rest - 1))

    # squeeze=False so `axes` is always 2-D, even for a single subplot.
    fig, axes = plt.subplots(1, len(ersds), figsize=(5 * len(ersds), 5),
                             squeeze=False)
    for ax, (title, ersd) in zip(axes[0], ersds):
        ax.plot(freqs, ersd)
        ax.set_title(title)
        ax.set_ylabel('ERSD')
        ax.set_xlabel('Frequency [Hz]')
    return fig
|