|
@@ -19,14 +19,16 @@ logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
|
|
|
|
class DataGenerator:
|
|
|
- def __init__(self, fs, X):
|
|
|
+ def __init__(self, fs, X, epoch_step=1.):
|
|
|
self.fs = int(fs)
|
|
|
self.X = X
|
|
|
+ self.epoch_step = epoch_step
|
|
|
|
|
|
def get_data_batch(self, current_index):
|
|
|
- # return 1s batch
|
|
|
+ # return epoch_step length batch
|
|
|
# create mne object
|
|
|
- data = self.X[:, current_index - self.fs:current_index].copy()
|
|
|
+ ind = int(self.epoch_step * self.fs)
|
|
|
+ data = self.X[:, current_index - ind:current_index].copy()
|
|
|
return self.fs, [], data
|
|
|
|
|
|
def loop(self, step_size=0.1):
|
|
@@ -35,13 +37,14 @@ class DataGenerator:
|
|
|
yield i / self.fs, self.get_data_batch(i)
|
|
|
|
|
|
|
|
|
-def validation(raw_val, event_id, model, state_change_threshold=0.8):
|
|
|
+def validation(raw_val, event_id, model, state_change_threshold=0.8, step_length=1.):
|
|
|
"""模型验证接口,使用指定数据进行训练+验证,绘制ersd map
|
|
|
Args:
|
|
|
raw (mne.io.Raw)
|
|
|
event_id (dict)
|
|
|
model: validate existing model,
|
|
|
state_change_threshold (float): default 0.8
|
|
|
+ step_length (float): batch data step length, default 1. (s)
|
|
|
|
|
|
Returns:
|
|
|
None
|
|
@@ -65,7 +68,7 @@ def validation(raw_val, event_id, model, state_change_threshold=0.8):
|
|
|
|
|
|
# validate with the second half
|
|
|
val_data = raw_val.get_data()
|
|
|
- data_gen = DataGenerator(fs, val_data)
|
|
|
+ data_gen = DataGenerator(fs, val_data, epoch_step=step_length)
|
|
|
rets = []
|
|
|
for time, data in data_gen.loop():
|
|
|
cls = controller.decision(data)
|