Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions nemo_rl/utils/nvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from typing import Generator

import pynvml

logger = logging.getLogger(__name__)


@contextlib.contextmanager
def nvml_context() -> Generator[None, None, None]:
Expand Down Expand Up @@ -78,13 +81,16 @@ def get_device_uuid(device_idx: int) -> str:


def get_free_memory_bytes(device_idx: int) -> float:
"""Get the free memory of a CUDA device in bytes using NVML."""
"""Get the free memory of a CUDA device in bytes using NVML, with torch.cuda fallback."""
global_device_idx = device_id_to_physical_device_id(device_idx)
with nvml_context():
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx)
return pynvml.nvmlDeviceGetMemoryInfo(handle).free
except pynvml.NVMLError as e:
raise RuntimeError(
f"Failed to get free memory for device {device_idx} (global index: {global_device_idx}): {e}"
)
logger.warning("NVML memory query failed for device %d: %s. Falling back to torch.cuda.mem_get_info.", device_idx, e)

import torch

free, _total = torch.cuda.mem_get_info(device_idx)
return free
Loading