Add CUDA health checking to /healthz endpoint #2204
Open · hansent wants to merge 2 commits into main from cuda-health-check
+249 −2
Changes from commit 1 of 2.
`inference/core/utils/cuda_health.py` (new file, +110 lines):

```python
"""CUDA health checking utilities.

Provides a fast, cached health check for GPU/CUDA state. Once CUDA fails,
the context is permanently corrupted and cannot recover without a process
restart. The failure state is cached to avoid repeatedly calling into a
broken CUDA runtime.
"""

import logging
import threading
import time
from typing import Optional, Tuple

logger = logging.getLogger(__name__)


class CudaHealthChecker:
    """Thread-safe CUDA health checker with failure caching.

    Once a CUDA failure is detected, the result is cached permanently
    (CUDA context corruption is unrecoverable). Subsequent calls return
    the cached failure immediately without touching CUDA.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._cuda_failed: bool = False
        self._failure_error: Optional[str] = None
        self._failure_time: Optional[float] = None
        self._gpu_available: Optional[bool] = None  # None = not yet checked

    def _is_gpu_environment(self) -> bool:
        """Check if we're running in a GPU environment. Cached after first call."""
        if self._gpu_available is not None:
            return self._gpu_available
        try:
            import torch

            self._gpu_available = torch.cuda.is_available()
        except ImportError:
            self._gpu_available = False
        except Exception:
            self._gpu_available = False
        return self._gpu_available

    def check_health(self) -> Tuple[bool, Optional[str]]:
        """Check CUDA health. Returns (is_healthy, error_message).

        - If not a GPU environment: returns (True, None) immediately
        - If CUDA previously failed: returns the cached failure immediately
        - Otherwise: runs a synchronize + mem_get_info check

        Thread-safe. The actual CUDA check is serialized by the lock to
        prevent concurrent CUDA calls during health checking.
        """
        # Fast path: not a GPU environment
        if not self._is_gpu_environment():
            return True, None

        # Fast path: already known to be failed (unrecoverable)
        if self._cuda_failed:
            return False, self._failure_error

        # Slow path: actually check CUDA
        with self._lock:
            # Double-check after acquiring the lock
            if self._cuda_failed:
                return False, self._failure_error

            try:
                import torch

                # Synchronize to surface any pending async CUDA errors
                torch.cuda.synchronize()
                # Query the runtime to verify it is still functional
                torch.cuda.mem_get_info()
                return True, None
            except Exception as e:
                error_msg = f"CUDA health check failed: {e}"
                logger.error(error_msg)
                self._cuda_failed = True
                self._failure_error = error_msg
                self._failure_time = time.time()
                return False, error_msg

    @property
    def is_failed(self) -> bool:
        return self._cuda_failed

    @property
    def failure_info(self) -> Optional[dict]:
        if not self._cuda_failed:
            return None
        return {
            "error": self._failure_error,
            "failed_at": self._failure_time,
        }


# Module-level singleton
_checker = CudaHealthChecker()


def check_cuda_health() -> Tuple[bool, Optional[str]]:
    """Module-level convenience function."""
    return _checker.check_health()


def get_cuda_health_checker() -> CudaHealthChecker:
    """Return the singleton for dependency injection / testing."""
    return _checker
```
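The PR title says this feeds the `/healthz` endpoint, but that wiring is not part of the files shown in this commit. A minimal sketch of how the integration could look, assuming a FastAPI server; the route shape, status code, and response body here are assumptions for illustration, not the PR's actual code:

```python
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from inference.core.utils.cuda_health import check_cuda_health

app = FastAPI()


@app.get("/healthz")
def healthz():
    """Liveness probe that also reports CUDA state."""
    is_healthy, error = check_cuda_health()
    if not is_healthy:
        # Returning 503 lets an orchestrator (e.g. Kubernetes) restart the
        # process; a corrupted CUDA context only recovers via restart.
        return JSONResponse(
            status_code=503,
            content={"status": "unhealthy", "error": error},
        )
    return {"status": "healthy"}
```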
The accompanying unit tests (new file, +120 lines):

```python
from unittest.mock import MagicMock, patch

import pytest

from inference.core.utils.cuda_health import CudaHealthChecker


class TestCudaHealthChecker:
    def setup_method(self):
        """Create a fresh checker for each test."""
        self.checker = CudaHealthChecker()

    def test_cpu_environment_no_torch(self):
        """When torch is not installed, should always return healthy."""
        with patch.dict("sys.modules", {"torch": None}):
            self.checker._gpu_available = None  # reset cache
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None

    def test_cpu_environment_no_cuda(self):
        """When torch is available but CUDA is not, should return healthy."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        with patch.dict("sys.modules", {"torch": mock_torch}):
            self.checker._gpu_available = None
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None

    def test_healthy_gpu(self):
        """When CUDA operations succeed, should return healthy."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.return_value = None
        mock_torch.cuda.mem_get_info.return_value = (4_000_000_000, 8_000_000_000)

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is True
            assert error is None
            mock_torch.cuda.synchronize.assert_called_once()
            mock_torch.cuda.mem_get_info.assert_called_once()

    def test_cuda_synchronize_failure(self):
        """When torch.cuda.synchronize() fails, should detect CUDA corruption."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError(
            "CUDA error: an illegal memory access was encountered"
        )

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is False
            assert "illegal memory access" in error
            assert self.checker.is_failed is True

    def test_mem_get_info_failure(self):
        """When mem_get_info fails (after synchronize succeeds), should detect failure."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.return_value = None
        mock_torch.cuda.mem_get_info.side_effect = RuntimeError("CUDA runtime error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            is_healthy, error = self.checker.check_health()
            assert is_healthy is False
            assert "CUDA runtime error" in error

    def test_failure_is_cached(self):
        """After the first CUDA failure, subsequent checks should return the
        cached failure without calling torch again."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            # First call: detects failure
            is_healthy1, error1 = self.checker.check_health()
            assert is_healthy1 is False
            assert mock_torch.cuda.synchronize.call_count == 1

            # Second call: returns cached failure, no new CUDA calls
            mock_torch.cuda.synchronize.reset_mock()
            is_healthy2, error2 = self.checker.check_health()
            assert is_healthy2 is False
            assert error2 == error1
            mock_torch.cuda.synchronize.assert_not_called()

    def test_failure_info(self):
        """failure_info should return error details after failure."""
        assert self.checker.failure_info is None

        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True
        mock_torch.cuda.synchronize.side_effect = RuntimeError("CUDA error")

        self.checker._gpu_available = True
        with patch.dict("sys.modules", {"torch": mock_torch}):
            self.checker.check_health()

        info = self.checker.failure_info
        assert info is not None
        assert "CUDA error" in info["error"]
        assert info["failed_at"] is not None

    def test_gpu_available_is_cached(self):
        """_is_gpu_environment() should only check torch once."""
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = False
        with patch.dict("sys.modules", {"torch": mock_torch}):
            assert self.checker._is_gpu_environment() is False
            assert self.checker._is_gpu_environment() is False
            # Only called once despite two invocations
            mock_torch.cuda.is_available.assert_called_once()
```
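One behavior the tests exercise only indirectly is the thread-safety claim: the double-checked lock in `check_health` serializes the actual CUDA probe while letting callers that arrive after a failure take the lock-free fast path. A small standalone sketch of that concurrent access pattern (hypothetical driver code, not part of the PR):

```python
import threading

from inference.core.utils.cuda_health import check_cuda_health


def probe() -> None:
    # Every thread shares the module-level singleton. At most one thread
    # at a time touches the CUDA runtime; once a failure is cached, all
    # subsequent calls return it without re-entering CUDA.
    is_healthy, error = check_cuda_health()
    print(is_healthy, error)


threads = [threading.Thread(target=probe) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```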