Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
run_async_nemo_gym_rollout,
run_multi_turn_rollout,
)
from nemo_rl.models.generation.dynamo import DynamoVllmConfig, DynamoVllmGeneration
from nemo_rl.models.generation.interfaces import GenerationInterface
from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration
from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
Expand Down Expand Up @@ -717,14 +718,42 @@ def initialize_generation_with_policy(
flush=True,
)

elif backend == "dynamo":
generation_config = cast(DynamoVllmConfig, generation_config)

generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
"hf_config_overrides", {}
)

def init_dynamo():
t0 = time.perf_counter()
pg = DynamoVllmGeneration(
cluster=inference_cluster, config=generation_config
)
return pg, time.perf_counter() - t0

policy_generation, policy = initialize_generation_with_policy(
init_generation_fn=init_dynamo,
generation_name="Dynamo+vLLM",
init_time_key="dynamo_init_time_s",
colocated_inference=colocated_inference,
worker_init_timing_metrics=worker_init_timing_metrics,
)

print(
f" ✓ Using Dynamo+vLLM backend for generation with {policy_config['model_name']}",
flush=True,
)

# Record when worker initialization completes (for calculating other setup time)
worker_init_complete_time = time.perf_counter() - setup_start_time

# print the node IP and GPU ID of the policy workers for debugging
policy.print_node_ip_and_gpu_id()

# if it is not colocated inference, initialize collective communication for update weights
if not colocated_inference:
# Dynamo backend does not support weight updates — skip collective init and refit.
if not colocated_inference and backend != "dynamo":
t0 = time.perf_counter()
ip, port = train_cluster.get_master_address_and_port()
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
Expand All @@ -745,7 +774,7 @@ def initialize_generation_with_policy(

# prepare refit info
state_dict_info = policy.prepare_refit_info()
if policy_generation is not None:
if policy_generation is not None and backend != "dynamo":
policy_generation.prepare_refit_info(state_dict_info)

# Calculate total setup time
Expand Down Expand Up @@ -985,7 +1014,7 @@ def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
return False

backend = generation_config.get("backend", "")
if backend != "vllm":
if backend != "vllm" and backend != "dynamo":
return False

vllm_cfg = generation_config.get("vllm_cfg", {})
Expand All @@ -999,13 +1028,18 @@ def _should_use_nemo_gym(master_config: MasterConfig) -> bool:
if not should_use_nemo_gym:
return should_use_nemo_gym

generation_config = master_config["policy"]["generation"]
backend = generation_config.get("backend", "")

# Dynamo backend always uses the frontend HTTP path — no extra validation needed.
if backend == "dynamo":
return should_use_nemo_gym

# Validate the setup for training with NeMo-Gym
assert _should_use_async_rollouts(master_config), (
"❌ Error: In order to use NeMo-Gym, you must use vllm generation backend with `async_engine: true`!"
"❌ Error: In order to use NeMo-Gym, you must use vllm or dynamo generation backend with `async_engine: true`!"
)

generation_config = master_config["policy"]["generation"]

# We piggyback off of `_should_use_async_rollouts` to guarantee the existence of these configs.
should_expose_http_server = generation_config["vllm_cfg"].get("expose_http_server")
assert should_expose_http_server, (
Expand Down Expand Up @@ -1333,6 +1367,8 @@ def grpo_train(
if policy_generation is None:
policy_generation = policy # type: ignore
NEED_REFIT = False
elif master_config["policy"]["generation"]["backend"] == "dynamo":
NEED_REFIT = False
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
assert policy_generation is not None # for mypy type check

Expand Down Expand Up @@ -2427,6 +2463,8 @@ def async_grpo_train(
if policy_generation is None:
policy_generation = policy
NEED_REFIT = False
elif master_config["policy"]["generation"]["backend"] == "dynamo":
NEED_REFIT = False
POLICY_GENERATION_STALE = True
assert policy_generation is not None

Expand Down
1 change: 1 addition & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
"nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
"nemo_rl.models.generation.dynamo.dynamo_worker.DynamoVllmWorker": PY_EXECUTABLES.SYSTEM,
"nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP,
"nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
"nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
Expand Down
17 changes: 6 additions & 11 deletions nemo_rl/environments/nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,30 +192,25 @@ def _postprocess_nemo_gym_to_nemo_rl_result(
if "generation_token_ids" not in output_item_dict:
continue

assert (
seen_token_ids
== output_item_dict["prompt_token_ids"][: len(seen_token_ids)]
), f"""Non-contiguous messages found! This may be a tokenization issue where certain tokens are combined when messages are concatenated, or it may be due to part of the chat history being truncated (like if super long history is truncated or if reasoning is stripped out).
Seen token IDs: {seen_token_ids}
Output prompt token IDs: {output_item_dict["prompt_token_ids"]}
"""

nemo_rl_message_log.append(
{
"role": "user",
"content": "",
"token_ids": torch.tensor(
output_item_dict["prompt_token_ids"][len(seen_token_ids) :]
output_item_dict["prompt_token_ids"][len(seen_token_ids) :],
dtype=torch.long,
),
}
)
nemo_rl_message_log.append(
{
"role": "assistant",
"content": "",
"token_ids": torch.tensor(output_item_dict["generation_token_ids"]),
"token_ids": torch.tensor(
output_item_dict["generation_token_ids"], dtype=torch.long
),
"generation_logprobs": torch.tensor(
output_item_dict["generation_log_probs"]
output_item_dict["generation_log_probs"], dtype=torch.float32
),
}
)
Expand Down
6 changes: 6 additions & 0 deletions nemo_rl/models/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from transformers import PreTrainedTokenizerBase

from nemo_rl.models.generation.dynamo import DynamoVllmConfig
from nemo_rl.models.generation.interfaces import GenerationConfig
from nemo_rl.models.generation.vllm import VllmConfig

Expand Down Expand Up @@ -69,4 +70,9 @@ def configure_generation_config(
else:
config["vllm_cfg"]["skip_tokenizer_init"] = True

# dynamo setting — always loads real weights (no refit support)
elif config["backend"] == "dynamo":
config = cast(DynamoVllmConfig, config)
config["vllm_cfg"]["load_format"] = "auto"

return config
16 changes: 16 additions & 0 deletions nemo_rl/models/generation/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_rl.models.generation.dynamo.config import DynamoVllmConfig
from nemo_rl.models.generation.dynamo.dynamo_generation import DynamoVllmGeneration
57 changes: 57 additions & 0 deletions nemo_rl/models/generation/dynamo/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Literal, NotRequired, TypedDict

from nemo_rl.models.generation.interfaces import GenerationConfig


class DynamoSpecificArgs(TypedDict):
    """vLLM engine arguments passed through to dynamo.vllm.

    Lives under the ``"vllm_cfg"`` key of :class:`DynamoVllmConfig` so that
    lookups like ``cfg["vllm_cfg"]["max_model_len"]`` work identically to the
    plain vLLM backend. Keys without ``NotRequired`` must always be present.
    """

    tensor_parallel_size: int  # model-parallel degree across GPUs within one replica
    pipeline_parallel_size: int  # pipeline-parallel degree for one replica
    expert_parallel_size: int  # expert (MoE) parallel degree
    gpu_memory_utilization: float  # fraction of GPU memory vLLM may claim (0.0-1.0)
    max_model_len: int  # maximum sequence length (prompt + generation)
    kv_cache_dtype: Literal["auto", "fp8", "fp8_e4m3"]  # KV-cache storage dtype
    precision: NotRequired[str]  # maps to vLLM --dtype
    load_format: NotRequired[str]  # weight-loading strategy; the dynamo backend forces "auto"
    enforce_eager: NotRequired[bool]  # NOTE(review): presumably disables CUDA graphs as in vLLM — confirm against dynamo.vllm
    hf_overrides: NotRequired[dict[str, Any]]  # HF config overrides (populated from policy "hf_config_overrides")
    extra_vllm_args: NotRequired[dict[str, Any]]  # escape hatch for additional vLLM args not modeled above


class DynamoCfg(TypedDict, total=False):
    """Dynamo infrastructure configuration.

    All keys are optional (``total=False``); port values of 0 request
    auto-assignment.
    """

    frontend_http_port: int  # 0 = auto-assign
    router_mode: str  # "round-robin", "kv", "random", "least-loaded"
    etcd_port: int  # 0 = auto-assign
    etcd_peer_port: int  # 0 = auto-assign
    namespace: str  # logical namespace isolating this deployment's components
    enable_planner: bool  # Launch planner + VirtualConnectorClient for autoscaling
    initial_dp_size: int  # Workers at startup (must be <= cluster.world_size() // tp_size)


class DynamoVllmConfig(GenerationConfig):
    """GenerationConfig extended with Dynamo-specific settings.

    Uses key name "vllm_cfg" so that cfg["vllm_cfg"]["max_model_len"]
    works the same as VllmConfig for NeMo-Gym compatibility.
    """

    vllm_cfg: DynamoSpecificArgs  # vLLM engine args; key name intentionally shared with VllmConfig
    dynamo_cfg: NotRequired[DynamoCfg]  # Dynamo runtime/infrastructure settings (ports, router, planner)
    vllm_kwargs: NotRequired[dict[str, Any]]  # NOTE(review): extra kwargs — consumer not visible in this view; confirm they reach the engine
Loading
Loading