Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,8 +2200,25 @@ def readiness(

@app.get("/healthz", status_code=200)
def healthz():
    """Health endpoint for Kubernetes liveness probe.

    Verifies CUDA context health when running on GPU. Returns 503 if
    CUDA is corrupted (unrecoverable - requires process restart).
    """
    # Local import keeps the HTTP layer decoupled from core utils at
    # module-import time (and avoids any import-cycle risk).
    from inference.core.utils.cuda_health import check_cuda_health

    is_healthy, error = check_cuda_health()
    if is_healthy:
        return {"status": "healthy"}
    # CUDA context corruption is unrecoverable in-process; report 503 so
    # the Kubernetes liveness probe restarts the pod.
    logger.error("CUDA health check failed: %s", error)
    return JSONResponse(
        content={
            "status": "unhealthy",
            "reason": "cuda_error",
        },
        status_code=503,
    )

if CORE_MODELS_ENABLED:
if CORE_MODEL_CLIP_ENABLED:
Expand Down
110 changes: 110 additions & 0 deletions inference/core/utils/cuda_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""CUDA health checking utilities.

Provides a fast, cached health check for GPU/CUDA state. Once CUDA fails,
the context is permanently corrupted and cannot recover without process restart.
The failure state is cached to avoid repeatedly calling into a broken CUDA runtime.
"""

import logging
import threading
import time
from typing import Optional, Tuple

logger = logging.getLogger(__name__)


class CudaHealthChecker:
    """Thread-safe CUDA health checker with failure caching.

    Once a CUDA failure is detected, the result is cached permanently
    (CUDA context corruption is unrecoverable). Subsequent calls return
    the cached failure immediately without touching CUDA.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Latched failure state: once set, never reset (the CUDA context
        # cannot recover without a process restart).
        self._cuda_failed: bool = False
        self._failure_error: Optional[str] = None
        self._failure_time: Optional[float] = None
        # None = not yet probed; True/False once torch availability was checked.
        self._gpu_available: Optional[bool] = None

    def _is_gpu_environment(self) -> bool:
        """Check if we're running in a GPU environment. Cached after first call.

        Unsynchronized on purpose: two threads racing here both compute the
        same value, so the benign race does not need the lock.
        """
        if self._gpu_available is not None:
            return self._gpu_available
        try:
            import torch

            self._gpu_available = torch.cuda.is_available()
        except Exception:
            # Single handler covers both ImportError (torch not installed)
            # and any failure inside torch.cuda.is_available() itself:
            # either way, treat this as a non-GPU environment.
            self._gpu_available = False
        return self._gpu_available

    def check_health(self) -> Tuple[bool, Optional[str]]:
        """Check CUDA health. Returns (is_healthy, error_message).

        - If not a GPU environment: returns (True, None) immediately
        - If CUDA previously failed: returns cached failure immediately
        - Otherwise: runs synchronize + mem_get_info check

        Thread-safe. The actual CUDA check is serialized by the lock to
        prevent concurrent CUDA calls during health checking.
        """
        # Fast path: not a GPU environment
        if not self._is_gpu_environment():
            return True, None

        # Fast path: already known to be failed (unrecoverable)
        if self._cuda_failed:
            return False, self._failure_error

        # Slow path: actually check CUDA
        with self._lock:
            # Double-check after acquiring lock: another thread may have
            # detected and latched the failure while we waited.
            if self._cuda_failed:
                return False, self._failure_error

            try:
                import torch

                # Synchronize to surface any pending async CUDA errors
                torch.cuda.synchronize()
                # Query runtime to verify it's still functional
                torch.cuda.mem_get_info()
                return True, None
            except Exception as e:
                error_msg = f"CUDA health check failed: {e}"
                logger.error(error_msg)
                # Latch the failure so future calls never touch the broken
                # CUDA runtime again.
                self._cuda_failed = True
                self._failure_error = error_msg
                self._failure_time = time.time()
                return False, error_msg

    @property
    def is_failed(self) -> bool:
        # True once a CUDA failure has been latched; never resets.
        return self._cuda_failed

    @property
    def failure_info(self) -> Optional[dict]:
        # None while healthy; error message + unix timestamp after failure.
        if not self._cuda_failed:
            return None
        return {
            "error": self._failure_error,
            "failed_at": self._failure_time,
        }


# Module-level singleton
_checker = CudaHealthChecker()


def check_cuda_health() -> Tuple[bool, Optional[str]]:
    """Check CUDA health via the shared module-level checker.

    Returns a ``(is_healthy, error_message)`` pair; ``error_message`` is
    ``None`` when healthy.
    """
    return _checker.check_health()


def get_cuda_health_checker() -> CudaHealthChecker:
    """Expose the shared checker instance (dependency injection / testing)."""
    return _checker
120 changes: 120 additions & 0 deletions tests/unit/core/test_cuda_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from unittest.mock import MagicMock, patch

import pytest

from inference.core.utils.cuda_health import CudaHealthChecker


class TestCudaHealthChecker:
    """Unit tests for CudaHealthChecker using a mocked ``torch`` module.

    Tests drive the checker through CPU-only, healthy-GPU, and failed-GPU
    paths by patching ``sys.modules["torch"]`` and pre-seeding the
    checker's private caches (``_gpu_available``) where needed.
    """

    def setup_method(self):
        """Create a fresh checker for each test."""
        self.checker = CudaHealthChecker()

    def test_cpu_environment_no_torch(self):
        """When torch is not installed, should always return healthy."""
        # Mapping "torch" to None makes ``import torch`` raise ImportError.
        with patch.dict("sys.modules", {"torch": None}):
            self.checker._gpu_available = None  # reset cache
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None

    def test_cpu_environment_no_cuda(self):
        """When torch is available but CUDA is not, should return healthy."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        with patch.dict("sys.modules", {"torch": mock_torch}):
            self.checker._gpu_available = None
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None

    def test_healthy_gpu(self):
        """When CUDA operations succeed, should return healthy."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.return_value = None
        mock_torch.cuda.mem_get_info.return_value = (4_000_000_000, 8_000_000_000)

        # Pre-seed the environment cache so check_health() skips the
        # availability probe and exercises the CUDA check path directly.
        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None
            mock_torch.cuda.synchronize.assert_called_once()
            mock_torch.cuda.mem_get_info.assert_called_once()

    def test_cuda_synchronize_failure(self):
        """When torch.cuda.synchronize() fails, should detect CUDA corruption."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError(
            "CUDA error: an illegal memory access was encountered"
        )

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is False
            assert "illegal memory access" in error
            assert self.checker.is_failed is True

    def test_mem_get_info_failure(self):
        """When mem_get_info fails (after synchronize succeeds), should detect failure."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.return_value = None
        mock_torch.cuda.mem_get_info.side_effect = RuntimeError("CUDA runtime error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is False
            assert "CUDA runtime error" in error

    def test_failure_is_cached(self):
        """After first CUDA failure, subsequent checks should return cached failure
        without calling torch again."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            # First call: detects failure
            is_healthy1, error1 = self.checker.check_health()
            assert is_healthy1 is False
            assert mock_torch.cuda.synchronize.call_count == 1

            # Second call: returns cached failure, no new CUDA calls
            mock_torch.cuda.synchronize.reset_mock()
            is_healthy2, error2 = self.checker.check_health()
            assert is_healthy2 is False
            assert error2 == error1
            mock_torch.cuda.synchronize.assert_not_called()

    def test_failure_info(self):
        """failure_info should return error details after failure."""
        # Before any failure, failure_info must be None.
        assert self.checker.failure_info is None

        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            self.checker.check_health()

        info = self.checker.failure_info
        assert info is not None
        assert "CUDA error" in info["error"]
        assert info["failed_at"] is not None

    def test_gpu_available_is_cached(self):
        """_is_gpu_environment() should only check torch once."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        with patch.dict("sys.modules", {"torch": mock_torch}):
            assert self.checker._is_gpu_environment() is False
            assert self.checker._is_gpu_environment() is False
            # Only called once despite two invocations
            mock_torch.cuda.is_available.assert_called_once()
Loading