Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/Automodel-workspace/Automodel
Submodule Automodel updated 673 files
83 changes: 83 additions & 0 deletions examples/configs/recipes/llm/dapo-gemma4-31b-it-4n8g-fsdp2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# DAPO recipe: Gemma4 31B-it, 4 nodes x 8 GPUs, FSDP2 (dtensor) training backend.
# Only overrides are listed; everything else comes from the base GRPO math recipe.
defaults: ../../grpo_math_1B.yaml
grpo:
  batch_multiplier: 3
  val_period: 15
  max_val_samples: 960
  val_batch_size: 960
  use_leave_one_out_baseline: false
  # DAPO-style dynamic sampling (filter prompts with degenerate reward groups).
  use_dynamic_sampling: true
  reward_scaling:
    enabled: true
    target_min: -1.0
  # Overlong-response shaping: responses entering the last `overlong_buffer_length`
  # tokens before `max_response_length` are penalized (DAPO soft overlong penalty).
  reward_shaping:
    enabled: true
    overlong_buffer_length: 384
    max_response_length: 3072
loss_fn:
  # Pure DAPO objective: no KL penalty against the reference policy.
  reference_policy_kl_penalty: 0.0
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 2
  # Asymmetric ("clip-higher") ratio clipping, per the DAPO paper.
  ratio_clip_max: 0.28
  ratio_clip_c: 10
checkpointing:
  checkpoint_dir: results/dapo-gemma4-31b-it-4n8g-fsdp2
  save_period: 15
policy:
  # NOTE(review): cluster-local model path — not portable outside this filesystem.
  model_name: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/models/google/gemma-4-31B-it
  tokenizer:
    name: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/models/google/gemma-4-31B-it
  train_global_batch_size: 32
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 4096
  logprob_chunk_size: 4096
  # Reduce peak memory during log-prob recomputation for the 31B model.
  offload_optimizer_for_logprob: true
  optimizer:
    kwargs:
      lr: 2.0e-06
      weight_decay: 0.1
  # Warmup (10 steps linear) then constant LR; milestones feeds SequentialLR.
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 1.0e-08
        end_factor: 1.0
        total_iters: 10
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 10
  dtensor_cfg:
    activation_checkpointing: true
  sequence_packing:
    enabled: false
  dynamic_batching:
    enabled: true
  make_sequence_length_divisible_by: 1
  generation:
    # Must stay consistent with grpo.reward_shaping.max_response_length above.
    max_new_tokens: 3072
    vllm_cfg:
      tensor_parallel_size: 4
data:
  max_input_seq_length: 2048
  train:
    dataset_name: DAPOMath17K
  validation:
    dataset_name: DAPOMathAIME2024
  default:
    prompt_file: null
env:
  math:
    num_workers: 16
    # DAPO-specific answer-verification implementation.
    math_verify_impl: dapo_math_verify
logger:
  log_dir: logs/dapo-gemma4-31b-it-4n8g-fsdp2
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemorl-gemma4
    name: dapo-gemma4-31b-it-4n8g-fsdp2
cluster:
  gpus_per_node: 8
  num_nodes: 4
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# GRPO recipe: Gemma4 E2B-it, 1 node x 8 GPUs, FSDP2 (dtensor) via Automodel.
# Only overrides are listed; everything else comes from the base GRPO math recipe.
defaults: ../../grpo_math_1B.yaml
grpo:
  val_period: 15
  max_val_samples: 960
  val_batch_size: 960
loss_fn:
  # No KL penalty against the reference policy.
  reference_policy_kl_penalty: 0.0
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 2
checkpointing:
  checkpoint_dir: results/grpo-gemma4-e2b-it-1n8g-fsdp2-automodel
policy:
  # NOTE(review): cluster-local model path — not portable outside this filesystem.
  model_name: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/models/google/gemma-4-E2B-it
  tokenizer:
    name: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_nemorl/users/shuangy/models/google/gemma-4-E2B-it
  train_global_batch_size: 32
  train_micro_batch_size: 1
  logprob_batch_size: 1
  max_total_sequence_length: 4096
  logprob_chunk_size: 4096
  optimizer:
    kwargs:
      lr: 1.0e-06
      weight_decay: 0.1
  # Warmup (10 steps linear) then constant LR; milestones feeds SequentialLR.
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 1.0e-08
        end_factor: 1.0
        total_iters: 10
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones:
        - 10
  dtensor_cfg:
    activation_checkpointing: true
  sequence_packing:
    enabled: false
  dynamic_batching:
    enabled: true
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 2048
data:
  max_input_seq_length: 2048
  train:
    dataset_name: DAPOMath17K
  validation:
    dataset_name: DAPOMathAIME2024
  default:
    prompt_file: null
env:
  math:
    num_workers: 16
    # DAPO-specific answer-verification implementation.
    math_verify_impl: dapo_math_verify
logger:
  log_dir: logs/grpo-gemma4-e2b-it-1n8g-fsdp2-automodel
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemorl-gemma4
    name: grpo-gemma4-e2b-it-1n8g-fsdp2-automodel
cluster:
  gpus_per_node: 8
33 changes: 21 additions & 12 deletions nemo_rl/models/automodel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,19 +686,28 @@ def setup_model_and_optimizer(
if is_tied_lm_head:
model.tie_weights()

# Freeze visual encoder when not doing VLM training.
# Without this, the optimizer creates state entries for visual params that never
# receive gradients, causing a key mismatch when resuming from checkpoint.
# Note: visual encoder is nested under model.model (e.g. model.model.visual for
# Freeze visual/audio encoders when not doing VLM training.
# Without this, the optimizer creates state entries for visual/audio params that
# never receive gradients, causing a key mismatch when resuming from checkpoint.
# Note: encoders may be nested under model.model (e.g. model.model.visual for
# Qwen3_5MoeForConditionalGeneration), not directly on model.
visual_module = getattr(getattr(model, "model", None), "visual", None) or getattr(
model, "visual", None
)
if not is_vlm and visual_module is not None:
for param in visual_module.parameters():
param.requires_grad_(False)
if rank == 0:
print("Froze visual encoder parameters for text-only training")
if not is_vlm:
for attr in (
"visual",
"vision_tower",
"audio_tower",
"embed_vision",
"embed_audio",
):
# Handle both direct attributes and nested under model.model (FSDP wrapping)
module = getattr(model, attr, None)
if module is None:
module = getattr(getattr(model, "model", None), attr, None)
if module is not None:
for param in module.parameters():
param.requires_grad_(False)
if rank == 0:
print(f"Froze {attr} parameters for text-only training")

# CPU offload if needed
if cpu_offload:
Expand Down
38 changes: 37 additions & 1 deletion nemo_rl/models/automodel/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,27 @@
]


def _needs_kv_cache_for_shared_layers(model: nn.Module) -> bool:
"""Check if the model uses KV sharing and needs use_cache=True for correct inference.

Models with num_kv_shared_layers > 0 (e.g. Gemma4 E2B) rely on DynamicCache
to pass K/V from anchor layers to shared layers. When use_cache=False,
past_key_values is None and shared layers cannot retrieve shared K/V,
producing incorrect outputs.
"""
model_config = getattr(model, "config", None)
text_config = (
getattr(model_config, "text_config", model_config) if model_config else None
)
return getattr(text_config, "num_kv_shared_layers", 0) > 0


def model_forward(
model: nn.Module,
processed_inputs: ProcessedInputs,
is_reward_model: bool = False,
allow_flash_attn_args: bool = True,
use_cache: bool = False,
) -> torch.Tensor:
"""Perform a single forward pass through the model.

Expand All @@ -75,6 +91,9 @@ def model_forward(
processed_inputs: ProcessedInputs containing all tensors for forward pass
is_reward_model: Whether this is a reward model
allow_flash_attn_args: Whether to pass flash_attn_kwargs to model
use_cache: Whether to use KV cache. Must be True for inference on models
with KV sharing (num_kv_shared_layers > 0). Must be False for training
(backward pass / gradient checkpointing).

Returns:
torch.Tensor: Output tensor from the model (logits)
Expand All @@ -83,7 +102,7 @@ def model_forward(
input_ids=processed_inputs.input_ids,
attention_mask=processed_inputs.attention_mask,
position_ids=processed_inputs.position_ids,
use_cache=False,
use_cache=use_cache,
)

# Add flash attention kwargs if applicable
Expand All @@ -103,6 +122,13 @@ def model_forward(
if is_gemma3 and "token_type_ids" not in model_args:
model_args["token_type_ids"] = torch.zeros_like(processed_inputs.input_ids)

# Gemma 4 requires mm_token_type_ids even for text-only inputs
if getattr(getattr(model, "config", None), "model_type", None) == "gemma4":
if "mm_token_type_ids" not in model_args:
model_args["mm_token_type_ids"] = torch.zeros_like(
processed_inputs.input_ids
)

# Reward models don't support flash_attn_kwargs
if is_reward_model:
if "flash_attn_kwargs" in model_args:
Expand Down Expand Up @@ -307,12 +333,22 @@ def forward_with_post_processing_fn(
data_dict = processed_mb.data_dict
processed_inputs = processed_mb.processed_inputs

# Models with KV sharing (num_kv_shared_layers > 0, e.g. Gemma4 E2B) need
# use_cache=True so that DynamicCache is created and shared layers can
# retrieve K/V from anchor layers. Without it, shared layers fall back to
# untrained K/V projections and produce garbage.
# Note: use_cache=True is incompatible with gradient/activation checkpointing
# (DynamicCache is stateful and recomputation doubles the cache). Callers
# must disable activation_checkpointing when training KV-sharing models.
use_cache = _needs_kv_cache_for_shared_layers(model)

# Model forward pass
outputs = model_forward(
model,
processed_inputs,
is_reward_model=is_reward_model,
allow_flash_attn_args=allow_flash_attn_args,
use_cache=use_cache,
)

# Extract logits from model outputs
Expand Down
2 changes: 2 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@ def _patch_vllm_hermes_tool_parser_thread_safety():
arch in getattr(hf_config, "architectures", [])
for arch in (
"Gemma3ForConditionalGeneration",
"Gemma4ForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
)
Expand All @@ -524,6 +525,7 @@ def _patch_vllm_hermes_tool_parser_thread_safety():
if arch
in (
"Gemma3ForConditionalGeneration",
"Gemma4ForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
)
Expand Down
Loading
Loading