1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -95,6 +95,7 @@ policy:
precision: "bfloat16"
logprob_chunk_size: null
offload_optimizer_for_logprob: false # Only useful for non-colocated generation since colocated generation will always offload optimizer to cuda before refit
use_pinned_optimizer_offload: false # Use pinned memory for optimizer D2H/H2D transfers

dtensor_cfg:
_v2: true
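To opt in, set the new flag in the policy section (a minimal sketch; `use_pinned_optimizer_offload` is the only new key, the surrounding keys already exist in the config above):

policy:
  offload_optimizer_for_logprob: false
  use_pinned_optimizer_offload: true  # pinned-memory D2H/H2D transfers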
98 changes: 84 additions & 14 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -1255,22 +1255,92 @@ def move_optimizer(self, device: str):
optimizer_state = self.optimizer.state
else:
optimizer_state = self.optimizer._get_state()

use_pinned = self.cfg.get("use_pinned_optimizer_offload", False)

if device == "cpu":
if use_pinned:
self._coalesced_optimizer_to_cpu(optimizer_state)
else:
self._optimizer_to_cpu(optimizer_state)
elif device == "cuda":
if use_pinned:
self._coalesced_optimizer_to_cuda(optimizer_state)
else:
self._optimizer_to_cuda(optimizer_state)
else:
raise ValueError(
f"Invalid device: {device}. Only strings 'cpu' and 'cuda' are supported."
)

def _optimizer_to_cpu(self, optimizer_state):
"""Offload optimizer state tensors to CPU using default pageable memory."""
        for _, state in optimizer_state.items():
            # Iterate through the state items (e.g., momentum, variance) for a parameter
            for k, v in state.items():
                # Pageable D2H copies are synchronous, so no explicit sync is needed
                if torch.is_tensor(v) and v.is_cuda:
                    state[k] = v.to("cpu")

def _optimizer_to_cuda(self, optimizer_state):
"""Reload optimizer state tensors to CUDA."""
for _, state in optimizer_state.items():
for k, v in state.items():
if torch.is_tensor(v) and not v.is_cuda:
state[k] = v.to("cuda")

def _get_or_alloc_pinned_buf(
self, attr_name: str, total_bytes: int
) -> torch.Tensor:
"""Return a cached pinned CPU buffer, allocating only on first use or resize."""
buf = getattr(self, attr_name, None)
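        # The buffer is uint8, so numel() equals its size in bytes; reuse it
        # whenever it is already large enough for the requested size.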
if buf is None or buf.numel() < total_bytes:
buf = torch.empty(
total_bytes, device="cpu", dtype=torch.uint8, pin_memory=True
)
setattr(self, attr_name, buf)
return buf

def _coalesced_optimizer_to_cpu(self, optimizer_state):
"""Offload all optimizer state tensors to CPU via a cached pinned buffer.

Packs all CUDA tensors into a single pre-allocated pinned CPU buffer,
eliminating per-tensor cudaHostAlloc overhead. The pinned buffer is
allocated once on first call and reused across iterations.
"""
ALIGN = 512
entries = []
total_bytes = 0

for _, state in optimizer_state.items():
for k, v in state.items():
if not torch.is_tensor(v) or not v.is_cuda:
continue
if v.dim() == 0:
state[k] = v.cpu()
continue
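                # Align each tensor's start to a 512-byte boundary so the uint8
                # slice can be viewed as any dtype (e.g. with total_bytes=1000,
                # the next offset rounds up to 1024).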
offset = (total_bytes + ALIGN - 1) // ALIGN * ALIGN
nbytes = v.numel() * v.element_size()
entries.append((state, k, v, offset, nbytes))
total_bytes = offset + nbytes

if not entries:
return

cpu_buf = self._get_or_alloc_pinned_buf("_optimizer_pinned_buf", total_bytes)

for state, k, v, offset, nbytes in entries:
dst = cpu_buf[offset : offset + nbytes].view(v.dtype).reshape(v.shape)
dst.copy_(v, non_blocking=True)
state[k] = dst

torch.cuda.synchronize()

def _coalesced_optimizer_to_cuda(self, optimizer_state):
"""Reload all optimizer state tensors back to CUDA."""
for _, state in optimizer_state.items():
for k, v in state.items():
if torch.is_tensor(v) and not v.is_cuda:
state[k] = v.to("cuda", non_blocking=True)
torch.cuda.synchronize()

def save_checkpoint(
self,
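Why pinned memory helps: page-locked host buffers let cudaMemcpy run as a direct asynchronous DMA, whereas pageable memory forces the driver through an extra staging copy per transfer. A standalone micro-benchmark sketch of the effect (not part of this PR; assumes a CUDA device and roughly 256 MB of free host RAM):

import time

import torch

src = torch.randn(64 * 1024 * 1024, device="cuda")  # ~256 MB of float32

# Pageable destination: the copy is effectively synchronous despite non_blocking.
pageable_dst = torch.empty(src.shape, dtype=src.dtype)
# Pinned destination: the copy is a direct DMA out of page-locked memory.
pinned_dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)

def timed_d2h(dst: torch.Tensor) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    dst.copy_(src, non_blocking=True)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"pageable D2H: {timed_d2h(pageable_dst):.4f}s")
print(f"pinned   D2H: {timed_d2h(pinned_dst):.4f}s")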
1 change: 1 addition & 0 deletions research/template_project/configs/grpo_math_1B.yaml
@@ -63,6 +63,7 @@ policy:
precision: "bfloat16"
logprob_chunk_size: null
offload_optimizer_for_logprob: false # Only useful for non-colocated generation since colocated generation will always offload optimizer to cuda before refit
use_pinned_optimizer_offload: false # Use pinned memory for optimizer D2H/H2D transfers

dtensor_cfg:
_v2: true
176 changes: 176 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -2847,3 +2847,179 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path):
finally:
policy.shutdown()
cluster.shutdown()


# ---------------------------------------------------------------------------
# Pinned optimizer offload tests
# ---------------------------------------------------------------------------


def _make_optimizer_state(shapes, dtype=torch.float32, include_scalar=False):
"""Build a fake optimizer_state dict mimicking Adam (exp_avg, exp_avg_sq)."""
state = {}
for i, shape in enumerate(shapes):
param_state = {
"exp_avg": torch.randn(shape, device="cuda", dtype=dtype),
"exp_avg_sq": torch.randn(shape, device="cuda", dtype=dtype).abs(),
}
if include_scalar:
param_state["step"] = torch.tensor(10.0, device="cuda", dtype=dtype)
state[i] = param_state
return state


class _FakeOptimizer:
"""Minimal optimizer stub that exposes state via _get_state()."""

def __init__(self, state):
self._state = state

def _get_state(self):
return self._state


def _make_pinned_test_worker(use_pinned, optimizer_state):
"""Build a stub worker with only the attributes needed for optimizer offload."""
from nemo_rl.models.policy.workers.megatron_policy_worker import (
MegatronPolicyWorkerImpl,
)

worker = object.__new__(MegatronPolicyWorkerImpl)
worker.cfg = {"use_pinned_optimizer_offload": use_pinned}
worker.optimizer = _FakeOptimizer(optimizer_state)
return worker


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestOptimizerOffloadRoundtrip:
"""Verify optimizer state survives a GPU -> CPU -> GPU round-trip."""

@pytest.mark.parametrize("use_pinned", [False, True], ids=["pageable", "pinned"])
def test_values_preserved(self, use_pinned):
shapes = [(64, 128), (32,), (256, 256)]
state = _make_optimizer_state(shapes, include_scalar=True)
worker = _make_pinned_test_worker(use_pinned, state)

originals = {}
for pid, param_state in state.items():
originals[pid] = {k: v.clone() for k, v in param_state.items()}

worker.move_optimizer("cpu")
for param_state in state.values():
for v in param_state.values():
assert not v.is_cuda, "tensor should be on CPU after offload"

worker.move_optimizer("cuda")
for pid, param_state in state.items():
for k, v in param_state.items():
assert v.is_cuda, f"tensor {k} should be back on CUDA"
torch.testing.assert_close(
v, originals[pid][k], msg=lambda m: f"param {pid}/{k}: {m}"
)

@pytest.mark.parametrize("use_pinned", [False, True], ids=["pageable", "pinned"])
def test_multiple_dtypes(self, use_pinned):
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
state = _make_optimizer_state([(128,)], dtype=dtype)
worker = _make_pinned_test_worker(use_pinned, state)
original = state[0]["exp_avg"].clone()

worker.move_optimizer("cpu")
worker.move_optimizer("cuda")

torch.testing.assert_close(state[0]["exp_avg"], original)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPinnedBufferCaching:
"""Verify the pinned path allocates the buffer once and reuses it."""

def test_buffer_reused_across_calls(self):
state = _make_optimizer_state([(64, 128)])
worker = _make_pinned_test_worker(True, state)

worker.move_optimizer("cpu")
buf1 = worker._optimizer_pinned_buf

worker.move_optimizer("cuda")
worker.move_optimizer("cpu")
buf2 = worker._optimizer_pinned_buf

assert buf1.data_ptr() == buf2.data_ptr(), "pinned buffer should be reused"

def test_no_pinned_buf_when_disabled(self):
state = _make_optimizer_state([(64, 128)])
worker = _make_pinned_test_worker(False, state)

worker.move_optimizer("cpu")
assert not hasattr(worker, "_optimizer_pinned_buf")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPinnedMemoryProperties:
"""Verify CPU tensors from the pinned path are actually pinned."""

def test_cpu_tensors_are_pinned(self):
state = _make_optimizer_state([(64, 128), (256,)])
worker = _make_pinned_test_worker(True, state)

worker.move_optimizer("cpu")
for param_state in state.values():
for k, v in param_state.items():
if v.dim() > 0:
assert v.is_pinned(), f"{k} should be in pinned memory"

def test_cpu_tensors_not_pinned_when_disabled(self):
state = _make_optimizer_state([(64, 128)])
worker = _make_pinned_test_worker(False, state)

worker.move_optimizer("cpu")
for param_state in state.values():
for v in param_state.values():
assert not v.is_pinned(), (
"pageable path should not produce pinned tensors"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPinnedOptimizerEdgeCases:
"""Edge cases: empty state, scalars only, invalid device."""

@pytest.mark.parametrize("use_pinned", [False, True], ids=["pageable", "pinned"])
def test_empty_optimizer_state(self, use_pinned):
state = {}
worker = _make_pinned_test_worker(use_pinned, state)
worker.move_optimizer("cpu")
worker.move_optimizer("cuda")

@pytest.mark.parametrize("use_pinned", [False, True], ids=["pageable", "pinned"])
def test_scalars_only(self, use_pinned):
state = {0: {"step": torch.tensor(5.0, device="cuda")}}
worker = _make_pinned_test_worker(use_pinned, state)

worker.move_optimizer("cpu")
assert not state[0]["step"].is_cuda

worker.move_optimizer("cuda")
assert state[0]["step"].is_cuda

@pytest.mark.parametrize("use_pinned", [False, True], ids=["pageable", "pinned"])
def test_invalid_device_raises(self, use_pinned):
state = _make_optimizer_state([(8,)])
worker = _make_pinned_test_worker(use_pinned, state)
with pytest.raises(ValueError, match="Invalid device"):
worker.move_optimizer("tpu")

def test_pinned_buffer_grows_if_needed(self):
state_small = _make_optimizer_state([(16,)])
worker = _make_pinned_test_worker(True, state_small)
worker.move_optimizer("cpu")
small_size = worker._optimizer_pinned_buf.numel()

state_large = _make_optimizer_state([(16,), (1024, 1024)])
worker.optimizer = _FakeOptimizer(state_large)
worker.move_optimizer("cuda")
worker.move_optimizer("cpu")
large_size = worker._optimizer_pinned_buf.numel()

assert large_size > small_size, "buffer should grow for larger state"