|
@@ -51,13 +51,11 @@ class TestOnlineSim(unittest.TestCase):
|
|
|
def test_sim(self):
|
|
|
metric_hmm, metric_nohmm, fig_pred = simulation(self.raw, self.event_id, model=self.model_path, state_change_threshold=0.7)
|
|
|
fig_pred.savefig('./tests/data/pred.pdf')
|
|
|
- print(metric_hmm, metric_nohmm)
|
|
|
self.assertTrue(metric_hmm[-2] > 0.3) # f1-score (with hmm)
|
|
|
self.assertTrue(metric_nohmm[-2] < 0.15) # f1-score (without hmm)
|
|
|
|
|
|
def test_val_model(self):
|
|
|
metrices, fig_conf = val_by_epochs(self.raw_val, self.model_path, self.event_id, 1.)
|
|
|
- print(metrices)
|
|
|
fig_conf.savefig('./tests/data/conf.pdf')
|
|
|
self.assertGreater(metrices[0], 0.85)
|
|
|
self.assertGreater(metrices[1], 0.7)
|