diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py
index 0b17916274..50027fd02e 100644
--- a/inference/core/interfaces/http/http_api.py
+++ b/inference/core/interfaces/http/http_api.py
@@ -2200,8 +2200,25 @@ def readiness(
 
         @app.get("/healthz", status_code=200)
         def healthz():
-            """Health endpoint for Kubernetes liveness probe."""
-            return {"status": "healthy"}
+            """Health endpoint for Kubernetes liveness probe.
+
+            Verifies CUDA context health when running on GPU. Returns 503 if
+            CUDA is corrupted (unrecoverable - requires process restart).
+            """
+            from inference.core.utils.cuda_health import check_cuda_health
+
+            is_healthy, error = check_cuda_health()
+            if is_healthy:
+                return {"status": "healthy"}
+            else:
+                logger.error("CUDA health check failed: %s", error)
+                return JSONResponse(
+                    content={
+                        "status": "unhealthy",
+                        "reason": "cuda_error",
+                    },
+                    status_code=503,
+                )
 
         if CORE_MODELS_ENABLED:
             if CORE_MODEL_CLIP_ENABLED:
diff --git a/inference/core/utils/cuda_health.py b/inference/core/utils/cuda_health.py
new file mode 100644
index 0000000000..eeef8fde0c
--- /dev/null
+++ b/inference/core/utils/cuda_health.py
@@ -0,0 +1,110 @@
+"""CUDA health checking utilities.
+
+Provides a fast, cached health check for GPU/CUDA state. Once CUDA fails,
+the context is permanently corrupted and cannot recover without process restart.
+The failure state is cached to avoid repeatedly calling into a broken CUDA runtime.
+"""
+
+import logging
+import threading
+import time
+from typing import Optional, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+class CudaHealthChecker:
+    """Thread-safe CUDA health checker with failure caching.
+
+    Once a CUDA failure is detected, the result is cached permanently
+    (CUDA context corruption is unrecoverable). Subsequent calls return
+    the cached failure immediately without touching CUDA.
+    """
+
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._cuda_failed: bool = False
+        self._failure_error: Optional[str] = None
+        self._failure_time: Optional[float] = None
+        self._gpu_available: Optional[bool] = None  # None = not yet checked
+
+    def _is_gpu_environment(self) -> bool:
+        """Check if we're running in a GPU environment. Cached after first call."""
+        if self._gpu_available is not None:
+            return self._gpu_available
+        try:
+            import torch
+
+            self._gpu_available = torch.cuda.is_available()
+        except ImportError:
+            self._gpu_available = False
+        except Exception:
+            self._gpu_available = False
+        return self._gpu_available
+
+    def check_health(self) -> Tuple[bool, Optional[str]]:
+        """Check CUDA health. Returns (is_healthy, error_message).
+
+        - If not a GPU environment: returns (True, None) immediately
+        - If CUDA previously failed: returns cached failure immediately
+        - Otherwise: runs synchronize + mem_get_info check
+
+        Thread-safe. The actual CUDA check is serialized by the lock to
+        prevent concurrent CUDA calls during health checking.
+        """
+        # Fast path: not a GPU environment
+        if not self._is_gpu_environment():
+            return True, None
+
+        # Fast path: already known to be failed (unrecoverable)
+        if self._cuda_failed:
+            return False, self._failure_error
+
+        # Slow path: actually check CUDA
+        with self._lock:
+            # Double-check after acquiring lock
+            if self._cuda_failed:
+                return False, self._failure_error
+
+            try:
+                import torch
+
+                # Synchronize to surface any pending async CUDA errors
+                torch.cuda.synchronize()
+                # Query runtime to verify it's still functional
+                torch.cuda.mem_get_info()
+                return True, None
+            except Exception as e:
+                error_msg = f"CUDA health check failed: {e}"
+                logger.error(error_msg)
+                self._cuda_failed = True
+                self._failure_error = error_msg
+                self._failure_time = time.time()
+                return False, error_msg
+
+    @property
+    def is_failed(self) -> bool:
+        return self._cuda_failed
+
+    @property
+    def failure_info(self) -> Optional[dict]:
+        if not self._cuda_failed:
+            return None
+        return {
+            "error": self._failure_error,
+            "failed_at": self._failure_time,
+        }
+
+
+# Module-level singleton
+_checker = CudaHealthChecker()
+
+
+def check_cuda_health() -> Tuple[bool, Optional[str]]:
+    """Module-level convenience function."""
+    return _checker.check_health()
+
+
+def get_cuda_health_checker() -> CudaHealthChecker:
+    """Return the singleton for dependency injection / testing."""
+    return _checker
diff --git a/tests/unit/core/test_cuda_health.py b/tests/unit/core/test_cuda_health.py
new file mode 100644
index 0000000000..da4e56cc99
--- /dev/null
+++ b/tests/unit/core/test_cuda_health.py
@@ -0,0 +1,120 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from inference.core.utils.cuda_health import CudaHealthChecker
+
+
+class TestCudaHealthChecker:
+    def setup_method(self):
+        """Create a fresh checker for each test."""
+        self.checker = CudaHealthChecker()
+
+    def test_cpu_environment_no_torch(self):
+        """When torch is not installed, should always return healthy."""
+        with patch.dict("sys.modules", {"torch": None}):
+            self.checker._gpu_available = None  # reset cache
+            is_healthy, error = self.checker.check_health()
+            assert is_healthy is True
+            assert error is None
+
+    def test_cpu_environment_no_cuda(self):
+        """When torch is available but CUDA is not, should return healthy."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = False
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            self.checker._gpu_available = None
+            is_healthy, error = self.checker.check_health()
+            assert is_healthy is True
+            assert error is None
+
+    def test_healthy_gpu(self):
+        """When CUDA operations succeed, should return healthy."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = True
+        mock_torch.cuda.synchronize.return_value = None
+        mock_torch.cuda.mem_get_info.return_value = (4_000_000_000, 8_000_000_000)
+
+        self.checker._gpu_available = True
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            is_healthy, error = self.checker.check_health()
+            assert is_healthy is True
+            assert error is None
+            mock_torch.cuda.synchronize.assert_called_once()
+            mock_torch.cuda.mem_get_info.assert_called_once()
+
+    def test_cuda_synchronize_failure(self):
+        """When torch.cuda.synchronize() fails, should detect CUDA corruption."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = True
+        mock_torch.cuda.synchronize.side_effect = RuntimeError(
+            "CUDA error: an illegal memory access was encountered"
+        )
+
+        self.checker._gpu_available = True
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            is_healthy, error = self.checker.check_health()
+            assert is_healthy is False
+            assert "illegal memory access" in error
+            assert self.checker.is_failed is True
+
+    def test_mem_get_info_failure(self):
+        """When mem_get_info fails (after synchronize succeeds), should detect failure."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = True
+        mock_torch.cuda.synchronize.return_value = None
+        mock_torch.cuda.mem_get_info.side_effect = RuntimeError("CUDA runtime error")
+
+        self.checker._gpu_available = True
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            is_healthy, error = self.checker.check_health()
+            assert is_healthy is False
+            assert "CUDA runtime error" in error
+
+    def test_failure_is_cached(self):
+        """After first CUDA failure, subsequent checks should return cached failure
+        without calling torch again."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = True
+        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")
+
+        self.checker._gpu_available = True
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            # First call: detects failure
+            is_healthy1, error1 = self.checker.check_health()
+            assert is_healthy1 is False
+            assert mock_torch.cuda.synchronize.call_count == 1
+
+            # Second call: returns cached failure, no new CUDA calls
+            mock_torch.cuda.synchronize.reset_mock()
+            is_healthy2, error2 = self.checker.check_health()
+            assert is_healthy2 is False
+            assert error2 == error1
+            mock_torch.cuda.synchronize.assert_not_called()
+
+    def test_failure_info(self):
+        """failure_info should return error details after failure."""
+        assert self.checker.failure_info is None
+
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = True
+        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")
+
+        self.checker._gpu_available = True
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            self.checker.check_health()
+
+        info = self.checker.failure_info
+        assert info is not None
+        assert "CUDA error" in info["error"]
+        assert info["failed_at"] is not None
+
+    def test_gpu_available_is_cached(self):
+        """_is_gpu_environment() should only check torch once."""
+        mock_torch = MagicMock()
+        mock_torch.cuda.is_available.return_value = False
+        with patch.dict("sys.modules", {"torch": mock_torch}):
+            assert self.checker._is_gpu_environment() is False
+            assert self.checker._is_gpu_environment() is False
+            # Only called once despite two invocations
+            mock_torch.cuda.is_available.assert_called_once()