Browse Source

Feat: 调试实现视频功能

DESKTOP-4GKCI80\Neuracle 1 year ago
parent
commit
5b972e1760
2 changed files with 84 additions and 16 deletions
  1. 23 16
      backend/device/video.py
  2. 61 0
      backend/tests/test_video.py

+ 23 - 16
backend/device/video.py

@@ -1,9 +1,9 @@
 """ Common function for camera based method """
-from fractions import Fraction
 import json
 import logging
 import os
 import time
+from datetime import datetime
 import threading
 
 import cv2
@@ -15,17 +15,27 @@ logger = logging.getLogger(__name__)
 
 
 class VideoCaptureThread:
-    def __init__(self, output_path, video_source=0, sync_device=None):
+    def __init__(self, output_dir, video_source=0, sync_device=None):
         super(VideoCaptureThread, self).__init__()
         self.video_source = video_source
-        self.cap = cv2.VideoCapture(self.video_source)
+        self.cap = cv2.VideoCapture(self.video_source, cv2.CAP_DSHOW)
+        while not self.cap.isOpened():
+            pass  # Wait for the capture to be initialized
+        self.cap.set(cv2.CAP_PROP_FPS, 30.0)
+        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280),
+        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
         
         self.sync_device = sync_device
         
-        self.output_path = output_path
-        frame_size = (int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
-                    int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
-        self.out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'avc1'), 20.0, frame_size)
+        now = datetime.now()
+        date_time_str = now.strftime("%m-%d-%Y-%H-%M-%S")
+        try:
+            os.makedirs(output_dir)
+        except FileExistsError:
+            pass
+        self.output_path = os.path.join(output_dir, f'video_recording_{date_time_str}.mp4')
+        
+        self.out = cv2.VideoWriter(self.output_path, cv2.VideoWriter_fourcc(*'mp4v'), 30.0, (1280, 720))
 
         self.videothread = threading.Thread(target=self.run)
         self.videothread.start()
@@ -35,22 +45,19 @@ class VideoCaptureThread:
         self.capture_video()
     
     def capture_video(self):
-        while not self.cap.isOpened():
-            pass  # Wait for the capture to be initialized
-        
-        # synchronize
+        ret, frame = self.cap.read()
+        if not ret:
+            logger.error("Error: Couldn't read frame. Exit.")
+            return
+        # synchronize after getting the first frame
         if self.sync_device is not None:
             self.sync_device.send_trigger(0xff)  # 255 for video ready
 
-        while True:
+        while self.cap.isOpened():
             ret, frame = self.cap.read()
 
             # TODO: online analysis  (500ms step, asychronize)
 
-            if not ret:
-                logger.error("Error: Couldn't read frame. Exit.")
-                break
-
             self.out.write(frame)
 
     def close(self):

+ 61 - 0
backend/tests/test_video.py

@@ -0,0 +1,61 @@
+import unittest
+import time
+import os
+import shutil
+from glob import glob
+
+import cv2
+import numpy as np
+
+from device.video import VideoCaptureThread
+from device.trigger_box import TriggerNeuracle
+from device.data_client import NeuracleDataClient
+
+
+def get_video_length(file_path):
+    cap = cv2.VideoCapture(file_path)
+
+    if not cap.isOpened():
+        return None
+
+    # Get the frames per second (fps) and total number of frames
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Calculate the duration of the video in seconds
+    duration = total_frames / fps
+
+    # Release the video capture object
+    cap.release()
+
+    return duration
+
+
+class TestVideo(unittest.TestCase):
+    def test_video_recording(self):
+        output_dir = './tests/data/video'
+        video_cam = VideoCaptureThread(output_dir=output_dir, video_source=1)
+        time.sleep(10)
+        video_cam.close()
+
+        # read video files
+        file = glob(os.path.join(output_dir, '*.mp4'))[0]
+        duration = get_video_length(file)
+        self.assertTrue(isinstance(duration, float))
+        self.assertTrue(duration > 0)
+
+        shutil.rmtree(output_dir)
+    
+    def test_video_sync(self):
+        output_dir = './tests/data/video'
+        trigger = TriggerNeuracle()
+        data_client = NeuracleDataClient(buffer_len=10.)
+        video_cam = VideoCaptureThread(output_dir=output_dir, video_source=1, sync_device=trigger)
+        time.sleep(5)
+        video_cam.close()
+        events = data_client.get_trial_data(clear=True)[1]
+        self.assertEqual(len(events), 1)
+        self.assertAlmostEqual(events[0, 2], 255)
+        data_client.close()
+
+        shutil.rmtree(output_dir)