diff --git a/nemo_rl/utils/nvml.py b/nemo_rl/utils/nvml.py
index 706c1b1cd6..de9ea55f63 100644
--- a/nemo_rl/utils/nvml.py
+++ b/nemo_rl/utils/nvml.py
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+import logging
 import os
 from typing import Generator
 
 import pynvml
 
+logger = logging.getLogger(__name__)
+
 
 @contextlib.contextmanager
 def nvml_context() -> Generator[None, None, None]:
@@ -78,13 +81,16 @@ def get_device_uuid(device_idx: int) -> str:
 
 
 def get_free_memory_bytes(device_idx: int) -> float:
-    """Get the free memory of a CUDA device in bytes using NVML."""
+    """Get the free memory of a CUDA device in bytes using NVML, with torch.cuda fallback."""
     global_device_idx = device_id_to_physical_device_id(device_idx)
     with nvml_context():
         try:
             handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
             return pynvml.nvmlDeviceGetMemoryInfo(handle).free
         except pynvml.NVMLError as e:
-            raise RuntimeError(
-                f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}"
-            )
+            logger.warning("NVML memory query failed for device %d: %s. Falling back to torch.cuda.mem_get_info.", device_idx, e)
+
+    import torch
+
+    free, _total = torch.cuda.mem_get_info(device_idx)
+    return free