|
@@ -76,22 +76,35 @@ def plot_time_frequency(data, events, sfreq, freqs, epoch_time_range, event_id):
|
|
|
|
|
|
|
|
|
|
def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
|
|
def plot_ersd(data, events, sfreq, epoch_time_range, event_id, rest_event=0):
|
|
|
|
+ n_ch = data.shape[0]
|
|
event_desc = {v:k for k, v in event_id.items()}
|
|
event_desc = {v:k for k, v in event_id.items()}
|
|
epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
|
|
epochs = cut_epochs((*epoch_time_range, sfreq), data, events[:, 0])
|
|
|
|
+
|
|
psd, freqs = psd_array_welch(epochs, sfreq, fmin=0, fmax=200)
|
|
psd, freqs = psd_array_welch(epochs, sfreq, fmin=0, fmax=200)
|
|
- mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=(0, 1))
|
|
|
|
|
|
+
|
|
|
|
+ mean_psd_rest = psd[events[:, -1] == rest_event].mean(axis=0)
|
|
ersds = []
|
|
ersds = []
|
|
for e in np.unique(events[:, -1]):
|
|
for e in np.unique(events[:, -1]):
|
|
if e != rest_event:
|
|
if e != rest_event:
|
|
- mean_psd = psd[events[:, -1] == e].mean(axis=(0, 1))
|
|
|
|
|
|
+ mean_psd = psd[events[:, -1] == e].mean(axis=0)
|
|
ersd = mean_psd / mean_psd_rest - 1
|
|
ersd = mean_psd / mean_psd_rest - 1
|
|
ersds.append((event_desc[e], ersd))
|
|
ersds.append((event_desc[e], ersd))
|
|
- fig, axes = plt.subplots(1, len(ersds), figsize=(5 * len(ersds), 5))
|
|
|
|
- if len(ersds) == 1:
|
|
|
|
- axes = [axes]
|
|
|
|
- for i, ersd in enumerate(ersds):
|
|
|
|
- axes[i].plot(freqs, ersd[1])
|
|
|
|
- axes[i].set_title(ersd[0])
|
|
|
|
- axes[i].set_ylabel('ERSD')
|
|
|
|
- axes[i].set_xlabel('Frequency [Hz]')
|
|
|
|
|
|
+ fig, axes = plt.subplots(n_ch, len(ersds), figsize=(3 * len(ersds), n_ch), sharex=True, sharey=True)
|
|
|
|
+
|
|
|
|
+ for i in range(n_ch):
|
|
|
|
+ if len(ersds) == 1:
|
|
|
|
+ for j, ersd in enumerate(ersds):
|
|
|
|
+ if i == 0:
|
|
|
|
+ axes[i].set_title(ersd[0])
|
|
|
|
+ axes[i].plot(freqs, ersd[1][i])
|
|
|
|
+ axes[i].set_ylabel(f'ch_{i + 1}')
|
|
|
|
+ axes[i].axhline(0, color='gray', linestyle='--')
|
|
|
|
+ else:
|
|
|
|
+ for j, ersd in enumerate(ersds):
|
|
|
|
+ if i == 0:
|
|
|
|
+ axes[i, j].set_title(ersd[0])
|
|
|
|
+ axes[i, j].plot(freqs, ersd[1][i])
|
|
|
|
+ axes[i, j].set_ylabel(f'ch_{i + 1}')
|
|
|
|
+ axes[i, j].axhline(0, color='gray', linestyle='--')
|
|
|
|
+ fig.suptitle('ERSD')
|
|
return fig
|
|
return fig
|