# data_client.py (2.6 KB)
  1. import socket
  2. import threading
  3. import numpy as np
  4. class NeuracleDataClient:
  5. UPDATE_INTERVAL = 0.04
  6. BYTES_PER_NUM = 4
  7. BUFFER_LEN = 1 # in secondes
  8. def __init__(self, n_channel=9, samplerate=1000, host='localhost', port=8712):
  9. self.n_channel = n_channel
  10. self.chunk_size = int(self.UPDATE_INTERVAL * samplerate * self.BYTES_PER_NUM * n_channel)
  11. self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
  12. self.buffer = []
  13. self.max_buffer_length = int(self.BUFFER_LEN / self.UPDATE_INTERVAL)
  14. self._host = host
  15. self._port = port
  16. # thread lock
  17. self.lock = threading.Lock()
  18. self.__datathread = threading.Thread(target=self.__recv_loop)
  19. self.samplerate = samplerate
  20. # start client
  21. self.__config()
  22. def __config(self):
  23. self.__sock.connect((self._host, self._port))
  24. self.__run_forever()
  25. def is_active(self):
  26. return self.__sock.fileno() != -1
  27. def close(self):
  28. self.__sock.close()
  29. self.__datathread.join()
  30. def __recv_loop(self):
  31. while self.__sock.fileno() != -1:
  32. try:
  33. data = self.__sock.recv(self.chunk_size)
  34. except OSError:
  35. break
  36. if len(data) % 4 != 0:
  37. continue
  38. self.lock.acquire()
  39. self.buffer.append(data)
  40. # remove old data
  41. if len(self.buffer) > self.max_buffer_length:
  42. del self.buffer[0]
  43. self.lock.release()
  44. def __run_forever(self):
  45. self.__datathread.start()
  46. def get_trial_data(self, clear=False):
  47. """
  48. called to copy trial data from buffer
  49. :args
  50. clear (bool):
  51. :return:
  52. samplerate: number, samplerate
  53. events: ndarray (n_events, 3), [onset, duration, event_label]
  54. data: ndarray with shape of (channels, timesteps)
  55. """
  56. self.lock.acquire()
  57. raw_data = self.buffer.copy()
  58. self.lock.release()
  59. total_data = b''.join(raw_data)
  60. byte_data = bytearray(total_data)
  61. if len(byte_data) % 4 != 0:
  62. raise ValueError
  63. data = np.frombuffer(byte_data, dtype='<f')
  64. data = np.reshape(data, (-1, self.n_channel))
  65. trigger_channel = data[:, -1]
  66. onset = np.flatnonzero(trigger_channel)
  67. event_label = trigger_channel[onset]
  68. events = np.stack((onset, np.zeros_like(onset), event_label), axis=1)
  69. if clear:
  70. self.buffer.clear()
  71. return self.samplerate, events, data[:, :-1].T