# data_client.py (3.5 KB)

  1. import socket
  2. import threading
  3. import numpy as np
  4. from scipy import signal
  5. class NeuracleDataClient:
  6. UPDATE_INTERVAL = 0.04
  7. BYTES_PER_NUM = 4
  8. def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712, buffer_len=1.):
  9. self.n_channel = n_channel
  10. self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
  11. self.chunk_size = int(self.UPDATE_INTERVAL * samplerate * self.BYTES_PER_NUM * n_channel)
  12. self.buffer = []
  13. self.max_buffer_length = int(buffer_len * samplerate)
  14. self._host = host
  15. self._port = port
  16. # thread lock
  17. self.lock = threading.Lock()
  18. self.__datathread = threading.Thread(target=self.__recv_loop)
  19. self.samplerate = samplerate
  20. self.filter = OnlineHPFilter(1, samplerate)
  21. # start client
  22. self.__config()
  23. def __config(self):
  24. self.__sock.connect((self._host, self._port))
  25. self.__run_forever()
  26. def is_active(self):
  27. return self.__sock.fileno() != -1
  28. def close(self):
  29. self.__sock.close()
  30. self.__datathread.join()
  31. def __recv_loop(self):
  32. while self.is_active():
  33. try:
  34. data = self.__sock.recv(self.chunk_size)
  35. except OSError:
  36. break
  37. if len(data) % 4 != 0:
  38. continue
  39. # unpack data
  40. data = self._unpack_data(data)
  41. # do highpass (exclude stim channel)
  42. data[:, :-1] = self.filter.filter_incoming(data[:, :-1])
  43. # update buffer
  44. self.lock.acquire()
  45. self.buffer.extend(data.tolist())
  46. # remove old data
  47. old_data_len = len(self.buffer) - self.max_buffer_length
  48. if old_data_len > 0:
  49. self.buffer = self.buffer[old_data_len:]
  50. self.lock.release()
  51. def _unpack_data(self, bytes_data):
  52. byte_data = bytearray(bytes_data)
  53. if len(byte_data) % 4 != 0:
  54. raise ValueError
  55. data = np.frombuffer(byte_data, dtype='<f')
  56. data = np.reshape(data, (-1, self.n_channel))
  57. # from uV to V, ignore event channel
  58. data[:, :-1] *= 1e-6
  59. return data
  60. def __run_forever(self):
  61. self.__datathread.start()
  62. def get_trial_data(self, clear=False):
  63. """
  64. called to copy trial data from buffer
  65. :args
  66. clear (bool):
  67. :return:
  68. samplerate: number, samplerate
  69. events: ndarray (n_events, 3), [onset, duration, event_label]
  70. data: ndarray with shape of (channels, timesteps)
  71. """
  72. self.lock.acquire()
  73. data = self.buffer.copy()
  74. self.lock.release()
  75. data = np.array(data)
  76. trigger_channel = data[:, -1]
  77. onset = np.flatnonzero(trigger_channel)
  78. event_label = trigger_channel[onset]
  79. events = np.stack((onset, np.zeros_like(onset), event_label), axis=1)
  80. if clear:
  81. self.buffer.clear()
  82. return self.samplerate, events, data[:, :-1].T
  83. class OnlineHPFilter:
  84. def __init__(self, freq=1, fs=1000):
  85. self.sos = signal.butter(2, freq, btype='hp', fs=fs, output='sos')
  86. self._z = None
  87. def filter_incoming(self, data):
  88. """
  89. Args:
  90. data (ndarray): (n_times, n_chs)
  91. Returns:
  92. y (ndarray): (n_times, n_chs)
  93. """
  94. if self._z is None:
  95. self._z = np.zeros((self.sos.shape[0], 2, data.shape[1]))
  96. y, self._z = signal.sosfilt(self.sos, data, axis=0, zi=self._z)
  97. return y