Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge
Submodule Megatron-Bridge updated 161 files
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-Bridge-workspace/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"flash-linear-attention",
"timm",
"open-clip-torch>=3.2.0",
"mlflow>=3.5.0",
"mlflow>=3.9.0",
"comet-ml>=3.50.0",
"torch>=2.6.0",
]
Expand Down
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM-workspace/Megatron-LM
Submodule Megatron-LM updated 225 files
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM-workspace/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# TODO(https://github.com/NVIDIA-NeMo/RL/issues/2111): upgrade to core_cu13 when we move to CUDA 13 base container
"transformer-engine[pytorch,core_cu12]",
# VCS dependency - must match pyproject.toml [tool.uv.sources]
"nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@63154570cea17f8805a7fd15cc3b8cc2919ba575",
"nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@15a851565a4ce846c04431ecb0cf09903ab4837e",
"tqdm",
"einops~=0.8",
"tensorstore~=0.1,!=0.1.46,!=0.1.72",
Expand Down
47 changes: 40 additions & 7 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@
)


def fix_gpt_oss_export_transpose(key: str, weight: torch.Tensor) -> torch.Tensor:
    """Transpose GPT-OSS expert ``down_proj`` weights into vLLM's expected layout.

    The ``down_proj`` weight layout differs between frameworks:
    - HF and Megatron store it as [in, out].
    - vLLM expects [out, in].
    Weights whose key does not end in ``mlp.experts.down_proj`` are returned
    untouched. See https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/3271
    for background.
    """
    if not key.endswith("mlp.experts.down_proj"):
        return weight
    # Swap the last two dimensions and materialize a contiguous copy so
    # downstream consumers can rely on the memory layout.
    return weight.transpose(-2, -1).contiguous()


class VllmInternalWorkerExtension:
def init_collective(
self,
Expand Down Expand Up @@ -199,20 +213,30 @@ def update_weights_via_ipc_zmq(self) -> bool:
shape, dtype = self.state_dict_info[key] # pyrefly
if isinstance(shape, list):
shape = torch.Size(shape)

# Get the weight from the buffer
size_in_bytes = dtype.itemsize * shape.numel()
weights.append(
(
key,
buffer[offset : offset + size_in_bytes]
.view(dtype=dtype)
.view(shape),
)
weight = (
buffer[offset : offset + size_in_bytes]
.view(dtype=dtype)
.view(shape)
)
# apply gpt-oss transpose fix
if (
"GptOssForCausalLM"
in self.model_runner.vllm_config.model_config.architectures
):
weight = fix_gpt_oss_export_transpose(key, weight)
weights.append((key, weight))

# Move offset to the next weight
aligned_size = calculate_aligned_size(size_in_bytes)
offset += aligned_size

assert offset == used_bytes, (
"Offset is not equal to used bytes, usually indicate inaccurate info like keys or cached dtype in state_dict_info"
)

# Load weights into the model
from nemo_rl.models.generation.vllm.quantization import fp8

Expand Down Expand Up @@ -276,6 +300,15 @@ def _load_model_weights(weights, model_runner):
"""
from nemo_rl.models.generation.vllm.quantization import fp8

# apply gpt-oss transpose fix
if (
"GptOssForCausalLM"
in self.model_runner.vllm_config.model_config.architectures
):
for idx, (key, weight) in enumerate(weights):
weight = fix_gpt_oss_export_transpose(key, weight)
weights[idx] = (key, weight)

policy_weights, draft_weights = self._split_policy_and_draft_weights(
weights
)
Expand Down
23 changes: 0 additions & 23 deletions nemo_rl/models/megatron/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,29 +142,6 @@ def destroy_parallel_state():
except ImportError:
pass

# Reset the third global async_calls instance in base strategy module
try:
import megatron.core.dist_checkpointing.strategies.base as base_strategy
from megatron.core.dist_checkpointing.strategies.async_utils import (
AsyncCallsQueue,
)

# Clean up and reset the global async_calls in base strategy
old_call_idx = getattr(base_strategy.async_calls, "call_idx", None)
num_unfinalized = base_strategy.async_calls.get_num_unfinalized_calls()
if num_unfinalized > 0:
print(
f"[WARNING] Resetting base strategy async_calls with {num_unfinalized} unfinalized calls"
)
try:
base_strategy.async_calls.close()
except:
pass
base_strategy.async_calls = AsyncCallsQueue()
print(f"[DEBUG] Reset base strategy async_calls (old call_idx: {old_call_idx})")
except ImportError:
pass


def setup_distributed() -> None:
"""Handle NCCL settings, dtype mapping, and basic config setup."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ requires-dist = [
"flash-linear-attention",
"timm",
"open-clip-torch>=3.2.0",
"mlflow>=3.5.0",
"mlflow>=3.9.0",
"comet-ml>=3.50.0",
"torch>=2.6.0",
]
Expand Down
12 changes: 7 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading