Browse Source

Feat: online filter

dk 1 year ago
parent
commit
626ab4c8ae
2 changed files with 58 additions and 19 deletions
  1. 48 16
      backend/device/data_client.py
  2. 10 3
      backend/tests/test_neo.py

+ 48 - 16
backend/device/data_client.py

@@ -2,19 +2,18 @@ import socket
 import threading
 
 import numpy as np
+from scipy import signal
 
 
 class NeuracleDataClient:
     UPDATE_INTERVAL = 0.04
     BYTES_PER_NUM = 4
-    BUFFER_LEN = 1  # in secondes
 
-    def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712):
+    def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712, buffer_len=1):
         self.n_channel = n_channel
-        self.chunk_size = int(self.UPDATE_INTERVAL * samplerate * self.BYTES_PER_NUM * n_channel)
         self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         self.buffer = []
-        self.max_buffer_length = int(self.BUFFER_LEN / self.UPDATE_INTERVAL)
+        self.max_buffer_length = int(buffer_len * samplerate)
         self._host = host
         self._port = port
         # thread lock
@@ -22,6 +21,8 @@ class NeuracleDataClient:
         self.__datathread = threading.Thread(target=self.__recv_loop)
         self.samplerate = samplerate
 
+        self.filter = OnlineHPFilter(1, samplerate)
+
         # start client
         self.__config()
 
@@ -37,19 +38,37 @@ class NeuracleDataClient:
         self.__datathread.join()
 
     def __recv_loop(self):
-        while self.__sock.fileno() != -1:
+        while self.is_active():
             try:
                 data = self.__sock.recv(self.chunk_size)
             except OSError:
                 break
             if len(data) % 4 != 0:
                 continue
+
+            # unpack data
+            data = self._unpack_data()
+
+            # do highpass (exclude stim channel)
+            data[:, :-1] = self.filter.filter_incoming(data[:, :-1])
+
+            # update buffer
             self.lock.acquire()
-            self.buffer.append(data)
+            self.buffer.extend(data.tolist())
             # remove old data
-            if len(self.buffer) > self.max_buffer_length:
-                del self.buffer[0]
+            old_data_len = len(self.buffer) - self.max_buffer_length
+            if old_data_len > 0:
+                self.buffer = self.buffer[old_data_len:]
             self.lock.release()
+    
+    def _unpack_data(self, bytes_data):
+        total_data = b''.join(bytes_data)
+        byte_data = bytearray(total_data)
+        if len(byte_data) % 4 != 0:
+            raise ValueError
+        data = np.frombuffer(byte_data, dtype='<f')
+        data = np.reshape(data, (-1, self.n_channel))
+        return data
 
     def __run_forever(self):
         self.__datathread.start()
@@ -65,15 +84,9 @@ class NeuracleDataClient:
             data: ndarray with shape of (channels, timesteps)
         """
         self.lock.acquire()
-        raw_data = self.buffer.copy()
+        data = self.buffer.copy()
         self.lock.release()
-        total_data = b''.join(raw_data)
-        byte_data = bytearray(total_data)
-        if len(byte_data) % 4 != 0:
-            raise ValueError
-        data = np.frombuffer(byte_data, dtype='<f')
-        data = np.reshape(data, (-1, self.n_channel))
-
+        data = np.array(data)
         trigger_channel = data[:, -1]
         onset = np.flatnonzero(trigger_channel)
         event_label = trigger_channel[onset]
@@ -81,3 +94,22 @@ class NeuracleDataClient:
         if clear:
             self.buffer.clear()
         return self.samplerate, events, data[:, :-1].T
+
+
+class OnlineHPFilter:
+    def __init__(self, freq=1, fs=1000):
+        self.sos = signal.butter(4, freq, btype='hp', fs=fs, output='sos')
+        self._z = None
+
+    def filter_incoming(self, data):
+        """
+        Args: 
+            data (ndarray): (n_times, n_chs)
+        Returns:
+            y (ndarray): (n_times, n_chs)
+        """
+        if self._z is None:
+            self._z = np.zeros((self.sos.shape[0], 2, data.shape[1]))
+
+        y, self._z = signal.sosfilt(self.sos, data, axis=0, zi=self._z)
+        return y

+ 10 - 3
backend/tests/test_neo.py

@@ -10,8 +10,10 @@ class TestNeo(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         config_info = settings.CONFIG_INFO
+        cls.buffer_len = 1
         cls.receiver = NeuracleDataClient(n_channel=len(config_info['channel_labels']), 
-                                          samplerate=config_info['sample_rate'])
+                                          samplerate=config_info['sample_rate'],
+                                          buffer_len=cls.buffer_len)
         cls.trigger = TriggerNeuracle()
     
     @classmethod
@@ -25,10 +27,15 @@ class TestNeo(unittest.TestCase):
     def test_get_data(self):
         time.sleep(1)
         fs, event, data = self.receiver.get_trial_data(clear=True)
-        self.assertTrue(data.shape[1] == settings.CONFIG_INFO['sample_rate'] * self.receiver.BUFFER_LEN)
+        self.assertTrue(data.shape[1] == settings.CONFIG_INFO['sample_rate'] * self.buffer_len)
         self.assertTrue(data.shape[0] == len(settings.CONFIG_INFO['channel_labels']) - 1)
         self.assertTrue(event.size == 0)
     
+    def test_highpass(self):
+        time.sleep(1)
+        fs, event, data = self.receiver.get_trial_data(clear=True)
+        self.assertTrue(np.mean(data) < 1e-5)
+    
     def test_send_trigger_and_receive(self):
         time.sleep(1)
         for i in range(5):
@@ -38,7 +45,7 @@ class TestNeo(unittest.TestCase):
         print(event.shape)
         self.assertTrue(event.shape[0] == 5)
         self.assertTrue(np.allclose(event[:, 2], np.arange(1, 6)))
-        self.assertTrue(data.shape[1] == settings.CONFIG_INFO['sample_rate'] * self.receiver.BUFFER_LEN)
+        self.assertTrue(data.shape[1] == settings.CONFIG_INFO['sample_rate'] * self.buffer_len)