diff --git a/src/inputs/plugins/vlm_coco_local.py b/src/inputs/plugins/vlm_coco_local.py index ee52fc800..22a615ebe 100644 --- a/src/inputs/plugins/vlm_coco_local.py +++ b/src/inputs/plugins/vlm_coco_local.py @@ -69,7 +69,8 @@ def __init__(self, config: VLM_COCO_LocalConfig): """ super().__init__(config) - self.device = "cpu" + self.device = "cuda" if torch.cuda.is_available() else "cpu" + logging.info(f"COCO Object Detector using device: {self.device}") self.detection_threshold = 0.2 self.camera_index = self.config.camera_index diff --git a/tests/inputs/plugins/test_vlm_coco_local.py b/tests/inputs/plugins/test_vlm_coco_local.py index 9542f1ca7..40cd21f34 100644 --- a/tests/inputs/plugins/test_vlm_coco_local.py +++ b/tests/inputs/plugins/test_vlm_coco_local.py @@ -38,8 +38,11 @@ def mock_cv2_video_capture(): @pytest.fixture def vlm_coco_local(mock_model, mock_check_webcam, mock_cv2_video_capture): - config = VLM_COCO_LocalConfig(camera_index=0) - return VLM_COCO_Local(config=config) + with patch( + "inputs.plugins.vlm_coco_local.torch.cuda.is_available", return_value=False + ): + config = VLM_COCO_LocalConfig(camera_index=0) + return VLM_COCO_Local(config=config) @pytest.mark.asyncio @@ -67,6 +70,7 @@ async def test_raw_to_text_with_detection(vlm_coco_local, mock_model): vlm_coco_local.class_labels = ["__background__", "cat", "dog"] vlm_coco_local.cam_third = 213 # 640/3 + vlm_coco_local.device = "cpu" dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8) result = await vlm_coco_local._raw_to_text(dummy_frame) assert isinstance(result, Message) @@ -99,8 +103,10 @@ async def test_poll_returns_none_on_failed_frame_read( ): """Test that _poll returns None when cap.read() fails.""" mock_cv2_video_capture.read.return_value = (False, None) - config = VLM_COCO_LocalConfig(camera_index=0) - sensor = VLM_COCO_Local(config=config) - + with patch( + "inputs.plugins.vlm_coco_local.torch.cuda.is_available", return_value=False + ): + config = VLM_COCO_LocalConfig(camera_index=0) + sensor = VLM_COCO_Local(config=config) result = await sensor._poll() assert result is None