diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.yaml new file mode 100644 index 0000000000..dda65c08b8 --- /dev/null +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.yaml @@ -0,0 +1,14 @@ +defaults: ./dpo-llama3.1-8b-tulu3-1n8g-fsdp2tp1.yaml +policy: + train_micro_batch_size: 2 + dtensor_cfg: + _v2: true + pipeline_parallel_size: 2 + automodel_kwargs: + force_hf: true + pipeline_config: + _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig + pp_schedule: 1F1B + pp_microbatch_size: 1 +checkpointing: + checkpoint_dir: results/dpo-llama3.1-8b-tulu3-1n8g-pp2 diff --git a/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.yaml b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.yaml new file mode 100644 index 0000000000..f63dda5c7c --- /dev/null +++ b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.yaml @@ -0,0 +1,10 @@ +defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml +policy: + sequence_packing: + enabled: true + dynamic_batching: + enabled: false + dtensor_cfg: + context_parallel_size: 2 +checkpointing: + checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack diff --git a/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4.yaml b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4.yaml new file mode 100644 index 0000000000..f2d7d9827c --- /dev/null +++ b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4.yaml @@ -0,0 +1,9 @@ +defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml +policy: + dynamic_batching: + enabled: false + dtensor_cfg: + context_parallel_size: 2 + expert_parallel_size: 4 +checkpointing: + checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-cp2ep4 diff --git a/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.yaml b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.yaml new file mode 100644 index 0000000000..e6526fe64b --- /dev/null +++ b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.yaml @@ -0,0 +1,8 @@ +defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml +policy: + sequence_packing: + enabled: true + dynamic_batching: + enabled: false +checkpointing: + checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack diff --git a/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.yaml b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.yaml new file mode 100644 index 0000000000..cd117d2086 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.yaml @@ -0,0 +1,10 @@ +defaults: ./grpo-moonlight-16b-automodel-1n8g-pp2ep4.yaml +policy: + max_total_sequence_length: 2048 + sequence_packing: + enabled: true + train_mb_tokens: 2048 +data: + max_input_seq_length: 2048 +checkpointing: + checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack diff --git a/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.yaml b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.yaml new file mode 100644 index 0000000000..553658576f --- /dev/null +++ b/examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.yaml @@ -0,0 +1,15 @@ +defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml +policy: + train_micro_batch_size: 4 + dtensor_cfg: + 
pipeline_parallel_size: 2 + expert_parallel_size: 4 + automodel_kwargs: + pipeline_config: + _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig + pp_schedule: Interleaved1F1B + pp_microbatch_size: 2 + dynamic_batching: + enabled: false +checkpointing: + checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-pp2ep4 diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-cp2ep4-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-cp2ep4-seqpack-automodel.yaml new file mode 100644 index 0000000000..c7156f08e0 --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-cp2ep4-seqpack-automodel.yaml @@ -0,0 +1,8 @@ +defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml +policy: + dtensor_cfg: + context_parallel_size: 2 + sequence_packing: + enabled: true +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-cp2ep4-seqpack-automodel diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.yaml new file mode 100644 index 0000000000..5a332e0b98 --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.yaml @@ -0,0 +1,6 @@ +defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml +policy: + sequence_packing: + enabled: true +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2cp2ep2-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2cp2ep2-seqpack-automodel.yaml new file mode 100644 index 0000000000..a7c7a4a93e --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2cp2ep2-seqpack-automodel.yaml @@ -0,0 +1,10 @@ +defaults: ./sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml +sft: + val_at_start: false +policy: + sequence_packing: + enabled: true + dtensor_cfg: + context_parallel_size: 2 +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2cp2ep2-seqpack-automodel diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml new file mode 100644 index 0000000000..8fe8eaf742 --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml @@ -0,0 +1,14 @@ +defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml +sft: + val_micro_batch_size: 4 +policy: + dtensor_cfg: + pipeline_parallel_size: 2 + expert_parallel_size: 4 + automodel_kwargs: + pipeline_config: + _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig + pp_schedule: 1f1b + pp_microbatch_size: 4 +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2ep4-automodel diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.yaml new file mode 100644 index 0000000000..58b0a73c32 --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.yaml @@ -0,0 +1,8 @@ +defaults: ./sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml +sft: + val_at_start: false +policy: + sequence_packing: + enabled: true +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel diff --git a/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-automodel.yaml b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-automodel.yaml new file mode 100644 
index 0000000000..eed18934f2 --- /dev/null +++ b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-automodel.yaml @@ -0,0 +1,6 @@ +defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml +policy: + sequence_packing: + enabled: false + dtensor_cfg: + context_parallel_size: 2 diff --git a/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.yaml new file mode 100644 index 0000000000..193c992877 --- /dev/null +++ b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.yaml @@ -0,0 +1,4 @@ +defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml +policy: + dtensor_cfg: + context_parallel_size: 2 diff --git a/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-automodel.yaml b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-automodel.yaml new file mode 100644 index 0000000000..8644bbc753 --- /dev/null +++ b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-automodel.yaml @@ -0,0 +1,4 @@ +defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml +policy: + sequence_packing: + enabled: false diff --git a/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml new file mode 100644 index 0000000000..b120783916 --- /dev/null +++ b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml @@ -0,0 +1,27 @@ +defaults: ../../sft.yaml +policy: + model_name: moonshotai/Moonlight-16B-A3B-Instruct + train_global_batch_size: 128 + train_micro_batch_size: 8 + max_total_sequence_length: 512 + sequence_packing: + enabled: true + dtensor_cfg: + expert_parallel_size: 8 + automodel_kwargs: + backend: + _target_: nemo_automodel.components.models.common.utils.BackendConfig + attn: te + linear: te + rms_norm: te + enable_deepep: true + rope_fusion: false + enable_hf_state_dict_adapter: true + enable_fsdp_optimizations: false + make_sequence_length_divisible_by: 4 +sft: + val_at_start: false +checkpointing: + enabled: false +cluster: + gpus_per_node: 8 diff --git a/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-pp2cp2ep2-seqpack-automodel.yaml b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-pp2cp2ep2-seqpack-automodel.yaml new file mode 100644 index 0000000000..5567a5b6b7 --- /dev/null +++ b/examples/configs/recipes/llm/sft-moonlight-16b-1n8g-pp2cp2ep2-seqpack-automodel.yaml @@ -0,0 +1,11 @@ +defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml +policy: + dtensor_cfg: + pipeline_parallel_size: 2 + context_parallel_size: 2 + expert_parallel_size: 4 + automodel_kwargs: + pipeline_config: + _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig + pp_schedule: 1f1b + pp_microbatch_size: 4 diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index dd46ff7a27..e32259548c 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -30,6 +30,68 @@ ) +class _AllGatherCPNoReduceBackward(torch.autograd.Function): + """Like AllGatherCPTensor but WITHOUT all-reduce in backward. + + Used when each CP rank computes loss on allgathered (identical) logprobs. + In this case, all ranks produce the same gradient, so all-reduce would + double it. This version just slices the gradient to the local chunk. 
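+
+    Example with cp_size=2: the sequence is split into 4 chunks and the
+    dual-chunk-swap layout places chunks (0, 3) on rank 0 and (1, 2) on
+    rank 1. Forward reorders the gathered chunks back to (0, 1, 2, 3);
+    backward re-selects only this rank's two chunks from the gradient.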
+ """ + + @staticmethod + def forward(ctx, tensor, cp_group, seq_dim=1): + cp_size = torch.distributed.get_world_size(cp_group) + cp_rank_chunks = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(cp_rank_chunks, tensor, group=cp_group) + + # Undo dual-chunk-swap ordering + tensor_chunks = [] + for chunk in cp_rank_chunks: + tensor_chunks.extend(torch.chunk(chunk, chunks=2, dim=seq_dim)) + chunk_indices = [] + for r in range(cp_size): + chunk_indices.append(r) + chunk_indices.append(2 * cp_size - r - 1) + pairs = sorted(zip(tensor_chunks, chunk_indices), key=lambda t: t[1]) + ret = torch.cat([c for c, _ in pairs], dim=seq_dim) + + ctx.seq_dim = seq_dim + ctx.cp_group = cp_group + return ret + + @staticmethod + def backward(ctx, grad_output): + cp_size = torch.distributed.get_world_size(ctx.cp_group) + cp_rank = torch.distributed.get_rank(ctx.cp_group) + seq_dim = ctx.seq_dim + # NO all-reduce — just select this rank's chunk + grad_output = grad_output.view( + *grad_output.shape[:seq_dim], + 2 * cp_size, + grad_output.shape[seq_dim] // (2 * cp_size), + *grad_output.shape[seq_dim + 1 :], + ) + index = torch.tensor( + [cp_rank, 2 * cp_size - cp_rank - 1], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + grad_input = grad_output.index_select(seq_dim, index) + grad_input = grad_input.view( + *grad_input.shape[:seq_dim], -1, *grad_input.shape[seq_dim + 2 :] + ) + return grad_input, None, None + + +_SELF_TP_GROUP: Optional[torch.distributed.ProcessGroup] = None + + +def _get_self_tp_group() -> torch.distributed.ProcessGroup: + """Cached single-rank TP group for CP-only logprob computation.""" + global _SELF_TP_GROUP + if _SELF_TP_GROUP is None: + _SELF_TP_GROUP = torch.distributed.new_group([torch.distributed.get_rank()]) + return _SELF_TP_GROUP + + @torch.no_grad() def _compute_distributed_log_softmax( vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup @@ -941,6 +1003,7 @@ def from_parallel_logits_to_logprobs( tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, cp_group: Optional[torch.distributed.ProcessGroup] = None, + cp_allgather_no_reduce: bool = False, chunk_size: Optional[int] = None, sampling_params: Optional[TrainingSamplingParams] = None, ) -> torch.Tensor: @@ -1020,10 +1083,12 @@ def from_parallel_logits_to_logprobs( ).contiguous() if cp_size > 1: - # we need to gather the logits by context parallelism - logprobs = allgather_cp_sharded_tensor( - logprobs, cp_group, seq_dim=1 - ) # , unpadded_seqlen=target.shape[1]) + if cp_allgather_no_reduce: + # No all-reduce in backward — used when each CP rank computes + # loss on allgathered logprobs (avoids 2x gradient). + logprobs = _AllGatherCPNoReduceBackward.apply(logprobs, cp_group, 1) + else: + logprobs = allgather_cp_sharded_tensor(logprobs, cp_group, seq_dim=1) if pad_len > 0: logprobs = logprobs[:, :-pad_len] @@ -1394,6 +1459,27 @@ def get_next_token_logprobs_from_logits( # slice off to the correct length to remove potential CP padding logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif context_parallel_group is not None and not isinstance( + next_token_logits, torch.distributed.tensor.DTensor + ): + # CP-only path (no TP): automodel backend with full vocab per rank. + # TE's thd_get_partitioned_indices uses dual-chunk-swap (same as + # _get_tokens_on_this_cp_rank), so from_parallel_logits_to_logprobs works. 
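+        # A single-rank "TP" group makes the vocab-parallel collectives inside
+        # from_parallel_logits_to_logprobs trivial no-ops: each rank already
+        # holds the full vocab, so vocab_start_index=0 and
+        # vocab_end_index=vocab_size cover the whole logits tensor.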
+ tp_group = _get_self_tp_group() + vocab_size = next_token_logits.shape[-1] + logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=0, + vocab_end_index=vocab_size, + tp_group=tp_group, + inference_only=False, + cp_group=context_parallel_group, + cp_allgather_no_reduce=True, # avoid 2x gradient from allgather backward + sampling_params=sampling_params, + ) + logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): logprobs = get_logprobs_from_vocab_parallel_logits( next_token_logits, diff --git a/nemo_rl/models/automodel/checkpoint.py b/nemo_rl/models/automodel/checkpoint.py index 00ac1923a7..62b49e5eff 100644 --- a/nemo_rl/models/automodel/checkpoint.py +++ b/nemo_rl/models/automodel/checkpoint.py @@ -54,6 +54,7 @@ def __init__( dp_mesh: DeviceMesh, tp_mesh: DeviceMesh, moe_mesh: Optional[DeviceMesh] = None, + pp_mesh: Optional[DeviceMesh] = None, ): """Initialize the AutomodelCheckpointManager. @@ -61,12 +62,14 @@ def __init__( dp_mesh: The data parallel device mesh. tp_mesh: The tensor parallel device mesh. moe_mesh: Optional MoE device mesh. + pp_mesh: Optional pipeline parallel device mesh. """ self.checkpointer: Optional[Checkpointer] = None self.checkpoint_config: Optional[AutomodelCheckpointingConfig] = None self.dp_mesh = dp_mesh self.tp_mesh = tp_mesh self.moe_mesh = moe_mesh + self.pp_mesh = pp_mesh def _get_dp_rank(self) -> int: """Get the data parallel rank.""" @@ -98,7 +101,11 @@ def init_checkpointer( dp_rank = self._get_dp_rank() tp_rank = self._get_tp_rank() - pp_rank = 0 + pp_rank = ( + torch.distributed.get_rank(self.pp_mesh.get_group()) + if self.pp_mesh is not None + else 0 + ) # Initialize a base config with sensible defaults base_cfg = AutomodelCheckpointingConfig( @@ -182,11 +189,14 @@ def _rebuild_checkpointer_addons(self) -> None: def save_checkpoint( self, - model: nn.Module, + model: nn.Module | list[nn.Module], weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: Optional[torch.optim.Optimizer | list[torch.optim.Optimizer]] = None, optimizer_path: Optional[str] = None, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + scheduler: Optional[ + torch.optim.lr_scheduler.LRScheduler + | list[torch.optim.lr_scheduler.LRScheduler] + ] = None, tokenizer: Optional[AutoTokenizer] = None, tokenizer_path: Optional[str] = None, checkpointing_cfg: Optional[CheckpointingConfig] = None, @@ -197,12 +207,16 @@ def save_checkpoint( The optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + For pipeline parallel, model/optimizer/scheduler can be lists (one per PP stage). + The automodel Checkpointer's ModelState/OptimizerState wrappers handle list inputs + by merging state dicts from all parts. + Args: - model: The model to save. + model: The model to save, or list of model parts for PP. weights_path: Path to save model weights. - optimizer: Optional optimizer to save. + optimizer: Optional optimizer (or list for PP) to save. optimizer_path: Optional path to save optimizer state. - scheduler: Optional learning rate scheduler. + scheduler: Optional learning rate scheduler (or list for PP). tokenizer: Optional tokenizer to save with the checkpoint. tokenizer_path: Optional path to save tokenizer separately. checkpointing_cfg: Checkpointing configuration. 
@@ -267,20 +281,25 @@ def save_checkpoint( def load_checkpoint( self, - model: nn.Module, + model: nn.Module | list[nn.Module], weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, + optimizer: Optional[torch.optim.Optimizer | list[torch.optim.Optimizer]] = None, optimizer_path: Optional[str] = None, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + scheduler: Optional[ + torch.optim.lr_scheduler.LRScheduler + | list[torch.optim.lr_scheduler.LRScheduler] + ] = None, ) -> None: """Load a checkpoint into the model using Automodel Checkpointer. + For pipeline parallel, model/optimizer/scheduler can be lists (one per PP stage). + Args: - model: The model to load weights into. + model: The model to load weights into, or list of model parts for PP. weights_path: Path to the checkpoint weights. - optimizer: Optional optimizer to load state into. + optimizer: Optional optimizer (or list for PP) to load state into. optimizer_path: Optional path to optimizer checkpoint. - scheduler: Optional learning rate scheduler. + scheduler: Optional learning rate scheduler (or list for PP). """ print(f"Loading weights from {weights_path}") assert self.checkpointer is not None, ( diff --git a/nemo_rl/models/automodel/config.py b/nemo_rl/models/automodel/config.py index ba959aa334..2244033f2c 100644 --- a/nemo_rl/models/automodel/config.py +++ b/nemo_rl/models/automodel/config.py @@ -36,6 +36,8 @@ class DistributedContext(NamedTuple): dp_size: int tp_size: int cp_size: int + pp_size: int + pp_mesh: Any # Optional[DeviceMesh], None when pp_size==1 class RuntimeConfig(NamedTuple): @@ -80,9 +82,9 @@ class ModelAndOptimizerState(NamedTuple): optimizer, scheduler, and metadata about the model type and configuration. """ - model: torch.nn.Module - optimizer: Optional[torch.optim.Optimizer] - scheduler: Optional[Any] + model: Any # nn.Module or ModelHandle for PP + optimizers: Optional[list[torch.optim.Optimizer]] + schedulers: Optional[list[Any]] is_hf_model: bool is_moe_model: bool is_reward_model: bool diff --git a/nemo_rl/models/automodel/data.py b/nemo_rl/models/automodel/data.py index 98eed48d4f..b85c7df0a4 100644 --- a/nemo_rl/models/automodel/data.py +++ b/nemo_rl/models/automodel/data.py @@ -55,6 +55,9 @@ class ProcessedInputs: cp_buffers: list[torch.Tensor] = field(default_factory=list) seq_index: Optional[torch.Tensor] = None + # THD batch for CP+seqpack path (bypasses FA2 and DTensor CP) + thd_batch: Optional["THDBatch"] = None + @property def has_context_parallel(self) -> bool: """Check if context parallel is enabled.""" @@ -100,6 +103,8 @@ def make_processed_microbatch_iterator( tokenizer: AutoTokenizer, cfg: dict[str, Any], cp_size: int, + cp_mesh: Any = None, + device_mesh: Any = None, ) -> Iterator[ProcessedMicrobatch]: """Wrap a raw microbatch iterator to yield processed microbatches. 
@@ -112,6 +117,8 @@ def make_processed_microbatch_iterator( tokenizer: Tokenizer for processing cfg: Configuration dictionary (enable_seq_packing is inferred from cfg["sequence_packing"]["enabled"]) cp_size: Context parallel size + cp_mesh: Optional CP device mesh (needed for seq_packing + CP) + device_mesh: Full device mesh with "cp" dim (needed for CP+seqpack THD path) Yields: ProcessedMicrobatch objects containing processed tensors ready for model forward @@ -131,6 +138,8 @@ def make_processed_microbatch_iterator( enable_seq_packing, cfg, cp_size, + cp_mesh=cp_mesh, + device_mesh=device_mesh, ) yield ProcessedMicrobatch( @@ -148,6 +157,8 @@ def get_microbatch_iterator( dp_mesh: Any, # noqa: ARG001 tokenizer: AutoTokenizer, cp_size: int = 1, + cp_mesh: Any = None, + device_mesh: Any = None, ) -> tuple[Iterator[ProcessedMicrobatch], int]: """Create processed microbatch iterator based on batching strategy. @@ -193,6 +204,8 @@ def get_microbatch_iterator( tokenizer, cfg, cp_size, + cp_mesh=cp_mesh, + device_mesh=device_mesh, ) return processed_iterator, iterator_len @@ -203,6 +216,8 @@ def process_microbatch( enable_seq_packing: bool, cfg: dict[str, Any], cp_size: int, + cp_mesh: Any = None, + device_mesh: Any = None, ) -> ProcessedInputs: """Process a microbatch and prepare inputs for model forward. @@ -212,13 +227,39 @@ def process_microbatch( enable_seq_packing: Whether sequence packing is enabled cfg: Configuration dictionary cp_size: Context parallel size + cp_mesh: CP device mesh (for CP+seqpack THD path) + device_mesh: Full device mesh with "cp" dim (for CP+seqpack THD path) Returns: ProcessedInputs containing all tensors and metadata for forward pass """ input_ids = mb.get("input_ids").cuda() - if enable_seq_packing: + if enable_seq_packing and cp_size > 1: + # CP+seqpack: use THD path with TE CP sharding. + # pack_for_thd handles CP-padding, THD conversion, and CP sharding. + # Pass token_mask so prompt tokens get labels=-100 in packed format. + token_mask = mb.get("token_mask", None) + thd_result = pack_for_thd( + input_ids=input_ids, + input_lengths=mb["input_lengths"], + packed_sequence_size=[len(mb["input_lengths"])], + padding_value=tokenizer.eos_token_id, + min_seq_len=cfg["sequence_packing"]["train_mb_tokens"], + cp_size=cp_size, + cp_mesh=cp_mesh, + device_mesh=device_mesh, + token_mask=token_mask, + ) + return ProcessedInputs( + input_ids=thd_result.input_ids, + position_ids=thd_result.position_ids, + attention_mask=None, + flash_attn_kwargs={}, + seq_len=thd_result.input_ids.shape[0], + thd_batch=thd_result, + ) + elif enable_seq_packing: input_ids, position_ids, _ = pack_sequences( input_ids=input_ids, input_lengths=mb["input_lengths"], @@ -348,6 +389,270 @@ def process_global_batch( } +@dataclass +class THDBatch: + """Packed sequence batch in THD-ready format for custom automodel models. + + Custom models (gpt-oss, qwen3-moe) with TE backend expect individual kwargs + (cu_seqlens, qkv_format) instead of HF's bundled flash_attn_kwargs. 
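+
+    Use ``to_model_kwargs()`` to turn the batch into the kwargs passed to the
+    model forward or the PP ``schedule.step()``: ``input_ids``, ``labels``,
+    ``position_ids``, ``cu_seqlens``, ``qkv_format="thd"`` and, when set,
+    ``max_seqlen``.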
+ """ + + input_ids: torch.Tensor # [packed_len] (1D THD) or [n_rows, packed_len] (2D for PP) + position_ids: torch.Tensor # same shape as input_ids + labels: torch.Tensor # same shape as input_ids + cu_seqlens: torch.Tensor # [num_seqs+1] or [n_rows, max_seqs+1] for PP + cu_seqlens_per_row: list # per-row clean cu_seqlens from actual lengths + cu_seqlens_padded_per_row: list # per-row clean cu_seqlens from CP-padded lengths + n_packed_rows: int # number of packed rows (= n_microbatches for PP) + + # CP-specific fields (set by pack_for_thd when cp_size > 1) + cp_size: int = 1 + cp_rank: int = 0 + max_seqlen: Optional[torch.Tensor] = None + + def to_model_kwargs(self, device: torch.device) -> dict[str, Any]: + """Build kwargs dict to pass to schedule.step() or model forward. + + For a single row (non-PP or PP with n_microbatches=1), tensors are 1D + (THD format). For multiple rows (PP with n_microbatches > 1), tensors + are 2D [n_rows, packed_len] so the PP schedule can split along dim 0. + The _thd_squeeze_hook on model parts squeezes dim 0 after splitting. + """ + result = { + "input_ids": self.input_ids.to(device), + "labels": self.labels.to(device), + "position_ids": self.position_ids.to(device), + "cu_seqlens": self.cu_seqlens.to(dtype=torch.int32, device=device), + "qkv_format": "thd", + } + # CP size/rank are already configured on the model's attention modules + # via apply_cp(). Don't pass them as kwargs — it can confuse TE backends. + if self.max_seqlen is not None: + result["max_seqlen"] = self.max_seqlen.to(device) + return result + + +def install_thd_squeeze_hook(model_parts: list) -> list: + """Install forward pre-hooks that squeeze cu_seqlens for PP microbatches. + + When the PP schedule splits [n_rows, packed_len] along dim 0, cu_seqlens + becomes [1, max_seqs+1]. Custom models expect 1D cu_seqlens, so the hook + squeezes dim 0. Input_ids stays [1, packed_len] — all custom models handle + this via their internal THD unsqueeze logic. + + Returns list of hook handles for removal. + """ + handles = [] + for part in model_parts: + def _squeeze_hook(module, args, kwargs): + if "cu_seqlens" not in kwargs: + return args, kwargs + kwargs = dict(kwargs) + if kwargs["cu_seqlens"].ndim == 2 and kwargs["cu_seqlens"].shape[0] == 1: + kwargs["cu_seqlens"] = kwargs["cu_seqlens"].squeeze(0) + return args, kwargs + + h = part.register_forward_pre_hook(_squeeze_hook, with_kwargs=True) + handles.append(h) + return handles + + +def _cp_pad_length(length: int, cp_size: int) -> int: + """Pad a sequence length to the nearest multiple of 2*cp_size.""" + divisor = 2 * cp_size + return ((length + divisor - 1) // divisor) * divisor + + +def pack_for_thd( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + packed_sequence_size: list[int], + padding_value: int, + min_seq_len: int = 0, + num_chunks: int = 1, + cp_size: int = 1, + cp_mesh: Any = None, + device_mesh: Any = None, + token_mask: Optional[torch.Tensor] = None, +) -> THDBatch: + """Pack sequences into THD-format batch, optionally with CP sharding. + + For CP > 1, individual sequences are padded to multiples of ``2*cp_size`` + before packing, and ``make_cp_batch_and_ctx`` shards tokens across CP ranks. + For PP, ``num_chunks`` controls how many microbatch rows are produced. + + Args: + input_ids: [num_sequences, max_seq_len] raw input + input_lengths: [num_sequences] actual lengths + packed_sequence_size: How many sequences per packed row, e.g. 
[4, 4] + padding_value: Pad token id + min_seq_len: Minimum packed row length + num_chunks: Number of THD chunks (= n_microbatches for PP). Default 1. + cp_size: Context parallel size. Default 1. + cp_mesh: CP device mesh (required when cp_size > 1). + device_mesh: Full device mesh with "cp" dim (required when cp_size > 1). + + Returns: + THDBatch with THD-format data, optionally CP-sharded. + """ + from nemo_automodel.components.distributed.thd_utils import ( + split_batch_into_thd_chunks, + ) + + # When CP > 1, pad each individual sequence to 2*cp_size multiple. + # This is done by adjusting input_lengths; pack_sequences will trim + # each sequence to input_lengths[i] tokens, so we use the CP-padded + # lengths as the actual lengths for packing. + if cp_size > 1: + cp_padded_lengths = torch.tensor( + [_cp_pad_length(l.item(), cp_size) for l in input_lengths], + dtype=input_lengths.dtype, + device=input_lengths.device, + ) + else: + cp_padded_lengths = input_lengths + + packed_ids, packed_pos_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=cp_padded_lengths, + packed_sequence_size=packed_sequence_size, + padding_value=padding_value, + return_attention_mask=False, + min_seq_len=min_seq_len, + ) + n_rows = len(packed_sequence_size) + row_len = packed_ids.shape[1] + + # Build seq_lens (actual) and seq_lens_padded (CP-padded per seq, last + # seq absorbs remaining space to fill row_len — matching automodel's + # packed_sequence_thd_collater convention). + seq_lens_list = [] + seq_lens_padded_list = [] + cu_seqlens_per_row = [] + cu_seqlens_padded_per_row = [] + seq_idx = 0 + for row_size in packed_sequence_size: + row_actual = input_lengths[seq_idx : seq_idx + row_size].clone() + row_cp_padded = cp_padded_lengths[seq_idx : seq_idx + row_size].clone() + + # Last sequence absorbs remaining space to fill the row + row_padded_for_thd = row_cp_padded.clone() + total_cp_padded = int(row_cp_padded.sum().item()) + remaining = row_len - total_cp_padded + if remaining > 0: + row_padded_for_thd[-1] = row_padded_for_thd[-1] + remaining + + seq_lens_list.append(row_actual) + seq_lens_padded_list.append(row_padded_for_thd) + + # Clean cu_seqlens from actual lengths (for data slicing in loss wrapper) + cu = torch.nn.functional.pad( + row_actual.to(torch.int32).cumsum(dim=0), (1, 0) + ) + cu_seqlens_per_row.append(cu) + + # Clean cu_seqlens from CP-padded lengths (for CP logit slicing) + cu_padded = torch.nn.functional.pad( + row_padded_for_thd.to(torch.int32).cumsum(dim=0), (1, 0) + ) + cu_seqlens_padded_per_row.append(cu_padded) + seq_idx += row_size + + # Pad to uniform number of sequences per row + max_seqs = max(len(s) for s in seq_lens_list) + seq_lens = torch.stack([ + torch.nn.functional.pad(s, (0, max_seqs - len(s)), value=-1000) + for s in seq_lens_list + ]) + seq_lens_padded = torch.stack([ + torch.nn.functional.pad(s, (0, max_seqs - len(s)), value=-1000) + for s in seq_lens_padded_list + ]) + + # Build labels with -100 for non-trainable positions: + # (a) CP padding between sequences (beyond actual but within CP-padded) + # (b) End-of-row padding (beyond total valid tokens) + # (c) Prompt tokens (where token_mask == 0, if provided) + # This is critical for the CP loss path which uses cross_entropy directly + # on all tokens (not bounded by cu_seqlens like SequencePackingLossWrapper). + # + # Also pack token_mask using the same CP-padded lengths so we can + # mask prompt tokens in the packed format. 
+ if token_mask is not None: + # token_mask: [batch_size, seq_len] with 0=prompt, 1=response. + # Pack using the same CP-padded lengths to align with packed_ids. + packed_mask, _, _ = pack_sequences( + input_ids=token_mask.float(), + input_lengths=cp_padded_lengths, + packed_sequence_size=packed_sequence_size, + padding_value=0, + return_attention_mask=False, + min_seq_len=min_seq_len, + ) + + labels = packed_ids.clone() + for row_idx in range(n_rows): + row_actual = seq_lens_list[row_idx] + row_cp_padded = cp_padded_lengths[ + sum(packed_sequence_size[:row_idx]) : sum(packed_sequence_size[:row_idx + 1]) + ] + # Mark padding between each sequence's actual and CP-padded boundary + pos = 0 + for seq_i in range(len(row_actual)): + actual = int(row_actual[seq_i].item()) + padded = int(row_cp_padded[seq_i].item()) + if actual < padded: + labels[row_idx, pos + actual : pos + padded] = -100 + pos += padded + # Mark end-of-row padding + if pos < row_len: + labels[row_idx, pos:] = -100 + + # Mask prompt tokens using token_mask + if token_mask is not None: + labels[packed_mask < 0.5] = -100 + + thd_input = { + "input_ids": packed_ids, + "labels": labels, + "position_ids": packed_pos_ids, + "seq_lens": seq_lens, + "seq_lens_padded": seq_lens_padded, + } + + if cp_size > 1: + # Use automodel's make_cp_batch_and_ctx for combined THD + CP sharding. + from nemo_automodel.components.distributed.cp_utils import ( + make_cp_batch_and_ctx, + ) + _, thd_batch = make_cp_batch_and_ctx( + device_mesh, + thd_input, + use_te=True, + padding_token_id=padding_value, + num_chunks=num_chunks, + ) + else: + thd_batch = split_batch_into_thd_chunks( + thd_input, + num_chunks=num_chunks, + padding_token_id=padding_value, + ) + + return THDBatch( + input_ids=thd_batch["input_ids"], + position_ids=thd_batch["position_ids"], + labels=thd_batch["labels"], + cu_seqlens=thd_batch["cu_seqlens"], + cu_seqlens_per_row=cu_seqlens_per_row, + cu_seqlens_padded_per_row=cu_seqlens_padded_per_row, + n_packed_rows=n_rows, + cp_size=thd_batch.get("cp_size", 1), + cp_rank=thd_batch.get("cp_rank", 0), + max_seqlen=thd_batch.get("max_seqlen"), + ) + + def check_sequence_dim(data: BatchedDataDict[Any]) -> Tuple[int, int]: """Check and validate sequence dimension across all tensors. diff --git a/nemo_rl/models/automodel/model_handle.py b/nemo_rl/models/automodel/model_handle.py new file mode 100644 index 0000000000..90ba5d1e0b --- /dev/null +++ b/nemo_rl/models/automodel/model_handle.py @@ -0,0 +1,150 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified model interface for single nn.Module and AutoPipeline (PP). + +Normalizes the API so callers never need ``if pp_enabled`` checks for basic +model operations (state_dict, eval, train, parameters, buffers, config, etc.). 
+""" + +from __future__ import annotations + +from typing import Any, Iterator + +import torch +from torch import nn + + +class ModelHandle: + """Thin wrapper that normalizes nn.Module vs AutoPipeline interface. + + With pipeline parallelism, ``from_pretrained`` returns an ``AutoPipeline`` + object instead of an ``nn.Module``. AutoPipeline lacks standard nn.Module + methods (``state_dict``, ``eval``, ``train``, ``parameters``, etc.). This + wrapper provides a unified API so that downstream code can treat both cases + identically. + + Usage:: + + handle = ModelHandle(model) # model is nn.Module or AutoPipeline + handle.eval() # works for both + for k, v in handle.state_dict_items(): + ... + handle.config.pad_token_id # works for both + """ + + def __init__(self, model: Any) -> None: + self._model = model + self._pp_enabled = hasattr(model, "parts") and hasattr(model, "info") + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def raw(self) -> Any: + """Access the underlying model/AutoPipeline directly. + + Use for PP-specific APIs like ``raw.info.schedule``, + ``raw.update_seq_len()``, ``raw.pp_batch_size``, etc. + """ + return self._model + + @property + def pp_enabled(self) -> bool: + return self._pp_enabled + + @property + def parts(self) -> list[nn.Module]: + """List of nn.Module model parts (always a list, even for non-PP).""" + if self._pp_enabled: + return list(self._model.parts) + return [self._model] + + @property + def config(self) -> Any: + """Model config (HF AutoConfig or similar).""" + if self._pp_enabled: + return self._model.parts[0].config + return self._model.config + + @property + def has_first_stage(self) -> bool: + if self._pp_enabled: + return self._model.info.has_first_stage + return True + + @property + def has_last_stage(self) -> bool: + if self._pp_enabled: + return self._model.info.has_last_stage + return True + + # ------------------------------------------------------------------ + # nn.Module-like API + # ------------------------------------------------------------------ + + def state_dict(self) -> dict[str, Any]: + """Merged state dict across all parts.""" + if self._pp_enabled: + state: dict[str, Any] = {} + for part in self._model.parts: + state.update(part.state_dict()) + return state + return self._model.state_dict() + + def state_dict_items(self) -> Iterator[tuple[str, torch.Tensor]]: + """Iterate ``(key, tensor)`` pairs of state dict (memory-efficient).""" + for part in self.parts: + yield from part.state_dict().items() + + def state_dict_keys(self) -> list[str]: + """All state dict keys across all parts.""" + keys: list[str] = [] + for part in self.parts: + keys.extend(part.state_dict().keys()) + return keys + + def eval(self) -> "ModelHandle": + for part in self.parts: + part.eval() + return self + + def train(self, mode: bool = True) -> "ModelHandle": + for part in self.parts: + part.train(mode) + return self + + def parameters(self) -> Iterator[torch.nn.Parameter]: + for part in self.parts: + yield from part.parameters() + + def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]: + for part in self.parts: + yield from part.named_parameters() + + def buffers(self) -> Iterator[torch.Tensor]: + for part in self.parts: + yield from part.buffers() + + def to(self, device: str | torch.device) -> "ModelHandle": + for part in self.parts: + part.to(device) + return self + + def move_buffers_to(self, device: str | 
torch.device) -> None: + """Move buffers to device (FSDP modules don't move buffers automatically).""" + for part in self.parts: + for v in part.buffers(): + torch.utils.swap_tensors(v, v.to(device)) diff --git a/nemo_rl/models/automodel/pipeline_parallel.py b/nemo_rl/models/automodel/pipeline_parallel.py new file mode 100644 index 0000000000..0828fec0ec --- /dev/null +++ b/nemo_rl/models/automodel/pipeline_parallel.py @@ -0,0 +1,966 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline parallelism utilities for automodel-based training. + +Contains PP-specific classes and functions: +- Broadcast helpers for PP stage communication +- PPLossAdapter: stateful loss wrapper for PP schedules +- PPLogitsCapturer: logit capture for eval +- pp_forward_backward: PP schedule step/eval wrapper +- pp_forward_with_post_processing / pp_forward_loop: PP eval with post-processing +- reset_pp_stage_shapes_for_thd: THD stage shape management +- prepare_pp_seqpack_batch: pack sequences for PP training step +- pad_batch_for_pp: pad batch to pp_batch_size for PP schedule +""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: + from nemo_rl.models.automodel.data import THDBatch + from nemo_rl.models.automodel.train import ( + LogprobsPostProcessor, + TopkLogitsPostProcessor, + ) + +import torch +from torch import nn + +from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams +from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + +# --------------------------------------------------------------------------- +# Broadcast helpers +# --------------------------------------------------------------------------- + + +def broadcast_tensors_from_last_pp_stage( + tensors: dict[str, Optional[torch.Tensor]], + pp_mesh: Any, + has_last_stage: bool, +) -> dict[str, torch.Tensor]: + """Broadcast tensors from last PP stage to all stages. + + Adapted from nemo_rl/models/megatron/pipeline_parallel.py for device-mesh + based PP (no Megatron parallel state). + + Args: + tensors: Dict mapping names to tensors. On last stage, tensors are + populated; on other stages they may be None. + pp_mesh: Pipeline parallel DeviceMesh. + has_last_stage: Whether this rank owns the last pipeline stage. + + Returns: + Dict with the same keys, all tensors populated on every rank. 
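+
+    For example, the eval path passes ``{"logits": logits}`` on the last
+    stage and ``{"logits": None}`` elsewhere; after the call every PP rank
+    holds the same logits tensor.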
+ """ + pp_group = pp_mesh.get_group() + pp_ranks = torch.distributed.get_process_group_ranks(pp_group) + last_global_rank = pp_ranks[-1] + + result = {} + for name, tensor in tensors.items(): + # Broadcast metadata (shape, dtype) + if has_last_stage: + assert tensor is not None, f"Last stage must provide tensor '{name}'" + meta = [list(tensor.shape), str(tensor.dtype)] + else: + meta = [None, None] + meta_list = [meta] + torch.distributed.broadcast_object_list( + meta_list, src=last_global_rank, group=pp_group + ) + shape, dtype_str = meta_list[0] + + if not has_last_stage: + dtype = getattr(torch, dtype_str.replace("torch.", "")) + tensor = torch.empty(shape, dtype=dtype, device="cuda") + + torch.distributed.broadcast(tensor, src=last_global_rank, group=pp_group) + result[name] = tensor + + return result + + +def broadcast_loss_metrics_from_last_pp_stage( + metrics: Optional[list[dict[str, Any]]], + pp_mesh: Any, + has_last_stage: bool, +) -> list[dict[str, Any]]: + """Broadcast loss metrics from last PP stage to all stages. + + Args: + metrics: List of metric dicts (populated on last stage, None elsewhere). + pp_mesh: Pipeline parallel DeviceMesh. + has_last_stage: Whether this rank owns the last pipeline stage. + + Returns: + List of metric dicts on all ranks. + """ + pp_group = pp_mesh.get_group() + pp_ranks = torch.distributed.get_process_group_ranks(pp_group) + last_global_rank = pp_ranks[-1] + + obj = [metrics] + torch.distributed.broadcast_object_list(obj, src=last_global_rank, group=pp_group) + return obj[0] + + +# --------------------------------------------------------------------------- +# Stateful loss adapters +# --------------------------------------------------------------------------- + + +class PPLossAdapter: + """Stateful loss adapter for pipeline parallel schedules. + + Bridges NeMo RL's LossFunction interface with the PP schedule's + ``loss_fn(output, target)`` contract by pre-chunking RL data across + microbatches. + + The PP schedule calls ``__call__(output, target)`` once per microbatch. + This adapter indexes into pre-chunked RL tensors to compute the loss for + each microbatch. + + **Critical**: The loss is scaled by ``dp_size * cp_size`` to cancel FSDP's + automatic gradient averaging, matching the non-PP path. 
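+
+    Rough usage sketch::
+
+        adapter = PPLossAdapter(loss_fn, cfg, device_mesh, cp_mesh, tp_mesh,
+                                cp_size=cp_size, dp_size=dp_size)
+        adapter.set_microbatches(train_data, n_microbatches,
+                                 global_valid_seqs, global_valid_toks)
+        loss, metrics = pp_forward_backward(model, batch, adapter)
+        adapter.reset()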
+ """ + + def __init__( + self, + loss_fn: LossFunction, + cfg: Any, + device_mesh: Any, + cp_mesh: Any, + tp_mesh: Any, + cp_size: int, + dp_size: int, + enable_seq_packing: bool = False, + sampling_params: Optional[TrainingSamplingParams] = None, + ): + self._loss_fn = loss_fn + self._cfg = cfg + self._device_mesh = device_mesh + self._cp_mesh = cp_mesh + self._tp_mesh = tp_mesh + self._cp_size = cp_size + self._dp_size = dp_size + self._enable_seq_packing = enable_seq_packing + self._sampling_params = sampling_params + + self._microbatches: list[dict[str, Any]] = [] + self._cu_seqlens_list: list[Optional[torch.Tensor]] = [] + self._cu_seqlens_padded_list: list[Optional[torch.Tensor]] = [] + self._call_idx: int = 0 + self._all_metrics: list[dict[str, Any]] = [] + self._global_valid_seqs: Optional[torch.Tensor] = None + self._global_valid_toks: Optional[torch.Tensor] = None + self._num_global_batches: int = 1 + self._context_parallel_group: Any = None + + def set_microbatches( + self, + data_dict: Any, + n_microbatches: int, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + num_global_batches: int = 1, + cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens_padded: Optional[torch.Tensor] = None, + context_parallel_group: Any = None, + ) -> None: + """Pre-chunk RL tensors along batch dim into n_microbatches.""" + self._call_idx = 0 + self._all_metrics = [] + self._global_valid_seqs = global_valid_seqs + self._global_valid_toks = global_valid_toks + self._num_global_batches = num_global_batches + + self._microbatches = [] + for i in range(n_microbatches): + mb: dict[str, Any] = {} + for key in data_dict: + val = data_dict[key] + if torch.is_tensor(val) and val.shape[0] > 0: + chunks = torch.tensor_split(val, n_microbatches, dim=0) + mb[key] = chunks[i] + else: + mb[key] = val + self._microbatches.append(mb) + + # Set per-microbatch cu_seqlens for sequence packing loss. 
+ self._context_parallel_group = context_parallel_group + + def _split_cu_seqlens(cu): + if cu is None: + return [None] * n_microbatches + if isinstance(cu, list): + return cu + if cu.ndim == 1: + return [cu] * n_microbatches + return [cu[i] for i in range(n_microbatches)] + + self._cu_seqlens_list = _split_cu_seqlens(cu_seqlens) + self._cu_seqlens_padded_list = _split_cu_seqlens(cu_seqlens_padded) + + def __call__(self, output: Any, target: torch.Tensor) -> torch.Tensor: + """Called by PP schedule per microbatch.""" + logits = getattr(output, "logits", output) + # THD format: [total_tokens, vocab] → [1, total_tokens, vocab] + if logits.ndim == 2: + logits = logits.unsqueeze(0) + mb_data = self._microbatches[self._call_idx] + cu_seqlens = ( + self._cu_seqlens_list[self._call_idx] if self._cu_seqlens_list else None + ) + cu_seqlens_padded = ( + self._cu_seqlens_padded_list[self._call_idx] + if self._cu_seqlens_padded_list + else cu_seqlens + ) + self._call_idx += 1 + + if cu_seqlens is not None: + if not isinstance(mb_data, BatchedDataDict): + mb_data = BatchedDataDict(mb_data) + + prepare_loss_input_wrapped = partial( + prepare_loss_input, sampling_params=self._sampling_params + ) + loss_fn_wrapped = SequencePackingLossWrapper( + loss_fn=self._loss_fn, + prepare_fn=prepare_loss_input_wrapped, + cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded + if cu_seqlens_padded is not None + else cu_seqlens, + context_parallel_group=self._context_parallel_group, + ) + loss, loss_metrics = loss_fn_wrapped( + logits, + mb_data, + self._global_valid_seqs, + self._global_valid_toks, + ) + else: + # Standard (non-packed) loss path + log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1) + curr_logprobs = ( + log_probs[:, :-1] + .gather(dim=-1, index=target[:, 1:].unsqueeze(-1).clamp(min=0)) + .squeeze(-1) + ) + + if self._loss_fn.input_type == LossInputType.LOGPROB: + loss_input = {"next_token_logprobs": curr_logprobs} + else: + loss_input = {"logits": logits} + + loss, loss_metrics = self._loss_fn( + data=mb_data, + global_valid_seqs=self._global_valid_seqs, + global_valid_toks=self._global_valid_toks, + **loss_input, + ) + + # Scale metrics for aggregation + for k in loss_metrics: + if "_min" not in k and "_max" not in k: + loss_metrics[k] /= self._num_global_batches + + self._all_metrics.append(loss_metrics) + + # Scale loss to cancel FSDP's automatic gradient averaging + return loss * self._dp_size * self._cp_size + + def reset(self) -> None: + """Reset state for the next forward-backward call.""" + self._call_idx = 0 + self._all_metrics = [] + self._cu_seqlens_list = [] + self._cu_seqlens_padded_list = [] + + +# --------------------------------------------------------------------------- +# Logit/logprob capturers for eval +# --------------------------------------------------------------------------- + + +class PPLogitsCapturer: + """Pseudo-loss that captures logits from each PP microbatch during eval. + + Used with ``schedule.eval()`` to collect logits on the last pipeline + stage for downstream post-processing (logprobs, top-k, etc.). 
+ """ + + def __init__(self): + self.captured_logits: list[torch.Tensor] = [] + + def __call__(self, output: Any, target: torch.Tensor) -> torch.Tensor: + logits = getattr(output, "logits", output) + self.captured_logits.append(logits.detach()) + return torch.tensor(0.0, device="cuda") + + def reset(self) -> None: + self.captured_logits = [] + + +# --------------------------------------------------------------------------- +# PP forward/backward +# --------------------------------------------------------------------------- + + +def pp_forward_backward( + model: nn.Module, + batch: dict[str, torch.Tensor], + loss_adapter: PPLossAdapter, + *, + forward_only: bool = False, +) -> tuple[torch.Tensor, list[dict[str, Any]]]: + """Execute forward (and optionally backward) using the PP schedule. + + The PP schedule internally handles microbatch splitting, stage-to-stage + communication, and gradient accumulation. + + Args: + model: AutoPipeline model with .info and .parts attributes. + batch: Dict with at least ``input_ids`` and ``labels``. + May also contain ``flash_attn_kwargs``, ``position_ids``, etc. + for sequence packing support. Extra keys are passed as kwargs + to schedule.step/eval following automodel's train_ft.py pattern. + loss_adapter: PPLossAdapter (already configured via set_microbatches). + forward_only: If True, use schedule.eval() instead of schedule.step(). + + Returns: + Tuple of (total_loss, list_of_metric_dicts). + total_loss is the summed loss on last stage, 0.0 on other stages. + """ + schedule = model.info.schedule + has_first = model.info.has_first_stage + has_last = model.info.has_last_stage + + input_ids = batch.pop("input_ids") + targets = batch.pop("labels", None) if has_last else None + losses: Optional[list[torch.Tensor]] = [] if has_last else None + + # Inject the loss adapter into the schedule + schedule._loss_fn = loss_adapter + + # Build args: first stage receives input_ids, others don't + args = (input_ids,) if has_first else () + + # Pass remaining batch keys (flash_attn_kwargs, position_ids, etc.) + # as kwargs to the schedule, following automodel's train_ft.py pattern. + # Filter out None values and empty dicts to avoid PP chunking errors. + batch_kwargs = { + k: v + for k, v in batch.items() + if v is not None and not (isinstance(v, dict) and len(v) == 0) + } + + if forward_only: + schedule.eval(*args, target=targets, losses=losses, **batch_kwargs) + else: + schedule.step(*args, target=targets, losses=losses, **batch_kwargs) + + if has_last and losses: + total_loss = torch.sum(torch.stack(losses)) + else: + total_loss = torch.tensor(0.0, device="cuda") + + return total_loss, loss_adapter._all_metrics + + +# --------------------------------------------------------------------------- +# THD stage shape management +# --------------------------------------------------------------------------- + + +def reset_pp_stage_shapes_for_thd(model: Any, tokens_per_chunk: int) -> None: + """Reset PP stage shapes for THD format (packed sequences). + + THD format produces [1, T, dim] outputs instead of [batch, seq, dim]. + Must be called before each schedule.step() when sequence lengths change. 
+ """ + from nemo_automodel.components.distributed.pipelining.functional import ( + _get_hidden_and_vocab_size, + ) + + schedule = model.info.schedule + stages = model.info.stages + model_config = model.parts[0].config + hidden_size, vocab_size = _get_hidden_and_vocab_size(model_config) + + schedule._stages_forward_initialized = False + if hasattr(schedule, "_stages_backward_initialized"): + schedule._stages_backward_initialized = False + + for stage in stages: + try: + model_dtype = next(stage.submod.parameters()).dtype + except StopIteration: + model_dtype = torch.bfloat16 + + if stage.is_first: + stage.inputs_meta = ( + torch.empty(1, tokens_per_chunk, device="meta", dtype=torch.long), + ) + else: + stage.inputs_meta = ( + torch.empty( + 1, tokens_per_chunk, hidden_size, device="meta", dtype=model_dtype + ), + ) + + has_lm_head = ( + hasattr(stage.submod, "lm_head") and stage.submod.lm_head is not None + ) + out_dim = vocab_size if has_lm_head else hidden_size + stage._outputs_meta = ( + torch.empty(1, tokens_per_chunk, out_dim, device="meta", dtype=model_dtype), + ) + + +def _reset_pp_schedule_state( + model: Any, + seq_len: int, + *, + seqpack: bool = False, + is_hf_model: bool = False, + force: bool = False, +) -> None: + """Reset PP schedule state for the upcoming batch. + + For THD format (seqpack with custom models), delegates to + ``reset_pp_stage_shapes_for_thd``. Otherwise delegates to + ``model.update_seq_len()`` which skips when seq_len is unchanged. + + Args: + force: If True, clear cached seq_len to force re-initialization. + Required when switching between schedule.eval() and + schedule.step() since the schedule needs fresh state. + """ + if not hasattr(model, "update_seq_len"): + raise RuntimeError( + "AutoPipeline.update_seq_len() not found. " + "Pipeline parallelism requires nemo_automodel >= 0.4.0 " + "(PR #1689: dynamic sequence length support). " + "Please update the automodel submodule to latest main." + ) + if force: + model._pp_current_seq_len = None + if seqpack and not is_hf_model: + model._pp_current_seq_len = None + reset_pp_stage_shapes_for_thd(model, seq_len) + else: + model.update_seq_len(seq_len) + + +# --------------------------------------------------------------------------- +# PP seqpack batch preparation +# --------------------------------------------------------------------------- + + +def prepare_pp_seqpack_batch( + input_ids: torch.Tensor, + input_lengths: torch.Tensor, + accum_data: dict[str, Any], + pp_batch_size: int, + n_microbatches: int, + tokenizer_eos_id: int, + train_mb_tokens: int, + cp_size: int = 1, + cp_mesh: Any = None, + device_mesh: Any = None, + token_mask: Optional[torch.Tensor] = None, +) -> tuple[dict[str, Any], dict[str, Any], "THDBatch"]: + """Pack sequences into THD format for a PP training step. + + Handles dummy-padding when the batch has fewer than pp_batch_size sequences. + Returns the PP batch dict (for schedule.step), the updated accum_data + (with dummy sample_mask=0), and the THDBatch for loss metadata. 
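+
+    For example, with ``pp_batch_size=8`` and ``n_microbatches=2`` each packed
+    row holds ``8 // 2 = 4`` sequences; a batch of 6 sequences is first padded
+    with 2 dummy rows whose ``sample_mask`` entries are zeroed.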
+ """ + from nemo_rl.models.automodel.data import pack_for_thd + + actual_seqs = input_ids.shape[0] + pp_mbs = pp_batch_size // n_microbatches + + if actual_seqs < pp_batch_size: + pad_count = pp_batch_size - actual_seqs + input_ids = torch.cat( + [ + input_ids, + torch.zeros( + pad_count, + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ), + ], + dim=0, + ) + input_lengths = torch.cat( + [ + input_lengths, + torch.ones( + pad_count, dtype=input_lengths.dtype, device=input_lengths.device + ), + ] + ) + for key in accum_data: + val = accum_data[key] + if torch.is_tensor(val) and val.ndim >= 1 and val.shape[0] == actual_seqs: + pad_shape = (pad_count,) + val.shape[1:] + accum_data[key] = torch.cat( + [val, torch.zeros(pad_shape, dtype=val.dtype, device=val.device)] + ) + if "sample_mask" in accum_data: + accum_data["sample_mask"][actual_seqs:] = 0 + + if token_mask is not None: + if token_mask.shape[0] < len(input_lengths): + pad_rows = len(input_lengths) - token_mask.shape[0] + token_mask = torch.cat( + [ + token_mask, + torch.zeros( + pad_rows, + token_mask.shape[1], + dtype=token_mask.dtype, + device=token_mask.device, + ), + ], + dim=0, + ) + token_mask = token_mask[: len(input_lengths)] + + thd_batch = pack_for_thd( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[pp_mbs] * n_microbatches, + padding_value=tokenizer_eos_id, + min_seq_len=train_mb_tokens, + num_chunks=n_microbatches, + cp_size=cp_size, + cp_mesh=cp_mesh, + device_mesh=device_mesh, + token_mask=token_mask, + ) + pp_batch = thd_batch.to_model_kwargs(device=input_ids.device) + return pp_batch, accum_data, thd_batch + + +def pad_batch_for_pp( + input_ids: torch.Tensor, + pp_batch_size: int, +) -> tuple[torch.Tensor, int]: + """Pad input_ids to pp_batch_size with zero rows. Returns (padded_ids, actual_seqs).""" + actual_seqs = input_ids.shape[0] + if actual_seqs < pp_batch_size: + pad_count = pp_batch_size - actual_seqs + input_ids = torch.cat( + [ + input_ids, + torch.zeros( + pad_count, + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ), + ], + dim=0, + ) + return input_ids, actual_seqs + + +# --------------------------------------------------------------------------- +# PP forward with post-processing (eval path) +# --------------------------------------------------------------------------- + + +def pp_forward_with_post_processing( + model: nn.Module, + input_ids: torch.Tensor, + chunk_data: BatchedDataDict, + post_processor: Union["LogprobsPostProcessor", "TopkLogitsPostProcessor"], + pp_batch_size: int, + seq_dim_size: int, + pp_mesh: Any, + has_first_stage: bool, + has_last_stage: bool, + capturer: PPLogitsCapturer, +) -> dict[str, torch.Tensor]: + """Run PP schedule.eval() on one chunk and apply post-processing. + + Analogous to ``forward_with_post_processing_fn()`` in ``train.py`` but + for the PP eval path. Handles batch padding, schedule execution, logit + capture, broadcasting, and post-processor dispatch. + + Broadcast timing differs by post-processor type: + - LogprobsPostProcessor: broadcasts full logits first (post-processor + needs full vocab for log_softmax + gather), then post-processes. + - TopkLogitsPostProcessor: applies top-k on last stage first (reduces + [B,S,V] → [B,S,k]), then broadcasts the smaller tensors. + + Args: + model: AutoPipeline model with .info attribute. + input_ids: Token IDs for this chunk [actual_seqs, seq_len]. + chunk_data: Sliced BatchedDataDict for this chunk. 
+ post_processor: LogprobsPostProcessor or TopkLogitsPostProcessor. + pp_batch_size: Required batch size for the PP schedule. + seq_dim_size: Global max sequence length (for output padding). + pp_mesh: Pipeline parallel DeviceMesh. + has_first_stage: Whether this rank owns the first pipeline stage. + has_last_stage: Whether this rank owns the last pipeline stage. + capturer: PPLogitsCapturer instance (shared across chunks). + + Returns: + Dict of result tensors, e.g. ``{"logprobs": tensor}`` or + ``{"topk_logits": tensor, "topk_indices": tensor}``. + """ + from nemo_rl.models.automodel.data import ProcessedInputs + from nemo_rl.models.automodel.train import ( + LogprobsPostProcessor, + TopkLogitsPostProcessor, + ) + + schedule = model.info.schedule + + # 1. Pad batch and reset schedule state + capturer.reset() + input_ids, actual_seqs = pad_batch_for_pp(input_ids, pp_batch_size) + labels = input_ids.clone() + + _reset_pp_schedule_state(model, input_ids.shape[-1]) + + # 2. Run schedule.eval() + targets = labels if has_last_stage else None + losses = [] if has_last_stage else None + args = (input_ids,) if has_first_stage else () + schedule.eval(*args, target=targets, losses=losses) + + # 3. Dispatch based on post-processor type + if isinstance(post_processor, LogprobsPostProcessor): + # Broadcast full logits, then post-process on all ranks + if has_last_stage: + all_logits = torch.cat(capturer.captured_logits, dim=0) + tensors = {"logits": all_logits} + else: + tensors = {"logits": None} + + broadcasted = broadcast_tensors_from_last_pp_stage( + tensors, pp_mesh, has_last_stage + ) + logits = broadcasted["logits"] + + processed_inputs = ProcessedInputs( + input_ids=input_ids, + seq_len=input_ids.shape[1], + ) + token_logprobs = post_processor( + logits=logits, + data_dict=chunk_data, + processed_inputs=processed_inputs, + original_batch_size=actual_seqs, + original_seq_len=seq_dim_size, + sequence_dim=1, + ) + + # Trim dummy-padded rows and pad to global seq dim + token_logprobs = token_logprobs[:actual_seqs] + padding_needed = seq_dim_size - token_logprobs.shape[1] + if padding_needed > 0: + token_logprobs = torch.nn.functional.pad( + token_logprobs, (0, padding_needed), mode="constant", value=0.0 + ) + return {"logprobs": token_logprobs} + + elif isinstance(post_processor, TopkLogitsPostProcessor): + # Post-process on last stage first (reduces vocab dim), then broadcast + k = post_processor.k + if has_last_stage: + mb_vals, mb_idx = [], [] + for logits in capturer.captured_logits: + vals, idx = torch.topk(logits.float(), k=k, dim=-1) + mb_vals.append(vals) + mb_idx.append(idx) + topk_vals = torch.cat(mb_vals, dim=0) + topk_indices = torch.cat(mb_idx, dim=0) + tensors = {"topk_logits": topk_vals, "topk_indices": topk_indices} + else: + tensors = {"topk_logits": None, "topk_indices": None} + + broadcasted = broadcast_tensors_from_last_pp_stage( + tensors, pp_mesh, has_last_stage + ) + vals = broadcasted["topk_logits"] + idx = broadcasted["topk_indices"] + + # Pad to global seq dim + pad_needed = seq_dim_size - vals.shape[1] + if pad_needed > 0: + vals = torch.nn.functional.pad(vals, (0, 0, 0, pad_needed), value=0.0) + idx = torch.nn.functional.pad(idx, (0, 0, 0, pad_needed), value=0) + return {"topk_logits": vals, "topk_indices": idx} + + else: + raise TypeError( + f"Unsupported post-processor type for PP eval: {type(post_processor)}" + ) + + +def pp_forward_loop( + model: nn.Module, + data: BatchedDataDict, + post_processor: Union["LogprobsPostProcessor", "TopkLogitsPostProcessor"], + 
pp_batch_size: int, + seq_dim_size: int, + pp_mesh: Any, + has_first_stage: bool, + has_last_stage: bool, +) -> dict[str, torch.Tensor]: + """Run PP eval over all data in pp_batch_size chunks with post-processing. + + Analogous to how the non-PP path loops over microbatches calling + ``forward_with_post_processing_fn()`` per microbatch. + + Handles syncing total_samples across ranks (to prevent collective hangs + with uneven data), iterating chunks, and concatenating results. + + Args: + model: AutoPipeline model. + data: Full dataset (will be moved to CUDA). + post_processor: LogprobsPostProcessor or TopkLogitsPostProcessor. + pp_batch_size: Batch size per PP schedule eval call. + seq_dim_size: Global max sequence length. + pp_mesh: Pipeline parallel DeviceMesh. + has_first_stage: Whether this rank owns the first pipeline stage. + has_last_stage: Whether this rank owns the last pipeline stage. + + Returns: + Dict of concatenated result tensors across all chunks. + """ + capturer = PPLogitsCapturer() + schedule = model.info.schedule + schedule._loss_fn = capturer + + data.to("cuda") + all_input_ids = data.get("input_ids").cuda() + total_samples = all_input_ids.shape[0] + + # Sync across ranks to prevent collective hangs when data is uneven + max_samples = torch.tensor([total_samples], device="cuda") + torch.distributed.all_reduce(max_samples, op=torch.distributed.ReduceOp.MAX) + num_chunks = max(1, (int(max_samples.item()) + pp_batch_size - 1) // pp_batch_size) + + all_results: dict[str, list[torch.Tensor]] = {} + for chunk_idx in range(num_chunks): + start = chunk_idx * pp_batch_size + actual_chunk_size = min(pp_batch_size, total_samples - start) + chunk_data = data.slice(start, start + actual_chunk_size) + input_ids = chunk_data.get("input_ids").cuda() + + chunk_result = pp_forward_with_post_processing( + model=model, + input_ids=input_ids, + chunk_data=chunk_data, + post_processor=post_processor, + pp_batch_size=pp_batch_size, + seq_dim_size=seq_dim_size, + pp_mesh=pp_mesh, + has_first_stage=has_first_stage, + has_last_stage=has_last_stage, + capturer=capturer, + ) + + for key, tensor in chunk_result.items(): + if key not in all_results: + all_results[key] = [] + all_results[key].append(tensor) + + # Concatenate across chunks + return {key: torch.cat(tensors, dim=0) for key, tensors in all_results.items()} + + +# --------------------------------------------------------------------------- +# PP training forward/backward loop +# --------------------------------------------------------------------------- + + +def pp_train_forward_backward_loop( + *, + model: nn.Module, + model_parts: list[nn.Module], + batch: BatchedDataDict, + loss_adapter: PPLossAdapter, + pp_batch_size: int, + n_microbatches: int, + has_last_stage: bool, + pp_mesh: Any, + enable_seq_packing: bool, + is_hf_model: bool, + tokenizer_eos_id: int, + train_mb_tokens: int, + cp_size: int, + cp_mesh: Any, + device_mesh: Any, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + num_global_batches: int, + forward_only: bool = False, +) -> tuple[torch.Tensor, list[dict[str, Any]]]: + """Run the PP gradient-accumulation loop for one global batch. + + Encapsulates the inner accumulation loop of PP training, analogous to + ``automodel_forward_backward()`` in ``train.py`` for the non-PP path. 
+ + Handles: + - Syncing actual batch size across ranks for uniform ``num_accum_steps`` + - ``prepare_for_grad_accumulation`` / ``prepare_for_final_backward`` / + ``prepare_after_first_microbatch`` lifecycle calls + - Slicing the batch into accumulation-step-sized chunks + - Building PP batches (seqpack via ``prepare_pp_seqpack_batch`` or simple) + - Configuring the loss adapter per accumulation step + - Resetting PP schedule state (seq len / THD shapes) + - Calling ``pp_forward_backward`` per accumulation step + - Broadcasting and collecting metrics from last stage + + Gradient clipping, optimizer step, and loss broadcasting for reporting + are NOT included -- those remain in the caller (policy-level concerns). + + Returns: + Tuple of (total_loss, all_mb_metrics). + total_loss: accumulated loss across accum steps (on last stage; 0 elsewhere). + all_mb_metrics: list of metric dicts collected from all accum steps. + """ + from nemo_automodel.components.training.utils import ( + prepare_after_first_microbatch, + prepare_for_final_backward, + prepare_for_grad_accumulation, + ) + + # Sync actual batch size across all ranks so everyone has the same + # num_accum_steps (required for PP collectives). + actual_total = batch.get("input_ids").shape[0] + max_batch = torch.tensor([actual_total], device="cuda") + torch.distributed.all_reduce(max_batch, op=torch.distributed.ReduceOp.MAX) + num_accum_steps = max( + 1, (int(max_batch.item()) + pp_batch_size - 1) // pp_batch_size + ) + + if not forward_only: + prepare_for_grad_accumulation(model_parts, pp_enabled=True) + + total_loss = torch.tensor(0.0, device="cuda") + all_metrics: list[dict[str, Any]] = [] + + for accum_idx in range(num_accum_steps): + if not forward_only and accum_idx == num_accum_steps - 1: + prepare_for_final_backward(model_parts, pp_enabled=True) + + # Clamp to actual batch size so empty/partial accum steps on ranks + # with fewer samples still participate in collectives. 
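As a concrete illustration of the step-count sync and clamped slicing performed by this loop, here is a small self-contained sketch with toy values (illustrative only; plain `max(...)` stands in for the all-reduce MAX, and no torch.distributed is involved):

# Toy re-enactment of the accumulation-step bookkeeping described above.
pp_batch_size = 4
local_batch_sizes = [10, 12]                       # per-rank sample counts (uneven)
max_batch = max(local_batch_sizes)                 # stands in for all_reduce(MAX)
num_accum_steps = max(1, (max_batch + pp_batch_size - 1) // pp_batch_size)  # = 3 on every rank

for actual_total in local_batch_sizes:
    slices = []
    for accum_idx in range(num_accum_steps):
        start = accum_idx * pp_batch_size
        cstart, cend = min(start, actual_total), min(start + pp_batch_size, actual_total)
        slices.append((cstart, cend))
    print(actual_total, slices)
# rank with 10 samples -> [(0, 4), (4, 8), (8, 10)]
# rank with 12 samples -> [(0, 4), (4, 8), (8, 12)]

Both ranks execute the same three accumulation steps, so the PP collectives stay aligned even when one rank's final slice is only partially filled.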
+ start = accum_idx * pp_batch_size + cstart = min(start, actual_total) + cend = min(start + pp_batch_size, actual_total) + accum_data = {} + for key in batch: + val = batch[key] + if torch.is_tensor(val) and val.shape[0] >= actual_total: + accum_data[key] = val[cstart:cend] + else: + accum_data[key] = val + + input_ids = batch.get("input_ids")[cstart:cend].cuda() + + thd_batch = None + if enable_seq_packing: + input_lengths = batch["input_lengths"][cstart:cend] + token_mask = accum_data.get("token_mask", None) + pp_batch, accum_data, thd_batch = prepare_pp_seqpack_batch( + input_ids=input_ids, + input_lengths=input_lengths, + accum_data=accum_data, + pp_batch_size=pp_batch_size, + n_microbatches=n_microbatches, + tokenizer_eos_id=tokenizer_eos_id, + train_mb_tokens=train_mb_tokens, + cp_size=cp_size, + cp_mesh=cp_mesh, + device_mesh=device_mesh, + token_mask=token_mask, + ) + else: + pp_batch = { + "input_ids": input_ids, + "labels": input_ids.clone(), + } + + if has_last_stage: + cu_seqlens_per_row = ( + thd_batch.cu_seqlens_per_row if enable_seq_packing else None + ) + cu_padded = ( + thd_batch.cu_seqlens_padded_per_row + if enable_seq_packing and cp_size > 1 + else None + ) + cp_group = cp_mesh.get_group() if cp_size > 1 else None + loss_adapter.set_microbatches( + accum_data, + n_microbatches, + global_valid_seqs, + global_valid_toks, + num_global_batches=num_global_batches, + cu_seqlens=cu_seqlens_per_row, + cu_seqlens_padded=cu_padded, + context_parallel_group=cp_group, + ) + else: + loss_adapter.reset() + + _reset_pp_schedule_state( + model, + pp_batch["input_ids"].shape[-1], + seqpack=enable_seq_packing, + is_hf_model=is_hf_model, + force=True, + ) + + step_loss, mb_metrics_list = pp_forward_backward( + model=model, + batch=pp_batch, + loss_adapter=loss_adapter, + forward_only=forward_only, + ) + total_loss += step_loss.detach() + + if not forward_only and accum_idx == 0: + prepare_after_first_microbatch() + + if has_last_stage: + gb_loss_metrics = mb_metrics_list + else: + gb_loss_metrics = None + gb_loss_metrics = broadcast_loss_metrics_from_last_pp_stage( + gb_loss_metrics, pp_mesh, has_last_stage + ) + if gb_loss_metrics: + all_metrics.extend(gb_loss_metrics) + + return total_loss, all_metrics diff --git a/nemo_rl/models/automodel/setup.py b/nemo_rl/models/automodel/setup.py index a321df7a5f..9441501a47 100644 --- a/nemo_rl/models/automodel/setup.py +++ b/nemo_rl/models/automodel/setup.py @@ -278,7 +278,6 @@ def validate_and_prepare_config( "Please set policy.sequence_packing.enabled = False to train VLM models." 
) print(f"[Rank {rank}] Sequence packing is enabled for model {model_name}") - print(f"[Rank {rank}] Using FlashAttention2 for sequence packing") # Get HF config overrides hf_config_overrides = config.get("hf_config_overrides", {}) or {} @@ -293,18 +292,19 @@ def validate_and_prepare_config( # so we need to set it to None if sequence packing is disabled # See https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 cp_size_cfg = config["dtensor_cfg"]["context_parallel_size"] - attn_impl = ( - "flash_attention_2" - if (enable_seq_packing and cp_size_cfg == 1) - else ("sdpa" if cp_size_cfg > 1 else None) - ) + if enable_seq_packing: + attn_impl = "flash_attention_2" + elif cp_size_cfg > 1: + attn_impl = "sdpa" + else: + attn_impl = None # Load model config model_config = AutoConfig.from_pretrained( model_name, torch_dtype=torch.float32, # Always load in float32 for master weights trust_remote_code=True, - attn_implementation="flash_attention_2" if enable_seq_packing else None, + attn_implementation=attn_impl, **hf_config_overrides, ) @@ -347,13 +347,6 @@ def validate_and_prepare_config( cp_size = config["dtensor_cfg"].get("context_parallel_size", 1) sequence_parallel_enabled = config["dtensor_cfg"]["sequence_parallel"] - # Validate parallelization configuration - if cp_size > 1 and enable_seq_packing: - raise ValueError( - "Context parallel is not supported for sequence packing. " - "Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." - ) - if sequence_parallel_enabled and tp_size == 1: print( "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. " @@ -396,7 +389,10 @@ def setup_reference_model_state( >>> model = setup_model(...) >>> reference_model_state_dict = setup_reference_model_state(model) """ - return get_cpu_state_dict(model.state_dict().items(), pin_memory=True) + from nemo_rl.models.automodel.model_handle import ModelHandle + + handle = model if isinstance(model, ModelHandle) else ModelHandle(model) + return get_cpu_state_dict(handle.state_dict_items(), pin_memory=True) def setup_distributed( @@ -428,6 +424,7 @@ def setup_distributed( tp_size = config["dtensor_cfg"].get("tensor_parallel_size", 1) cp_size = config["dtensor_cfg"].get("context_parallel_size", 1) ep_size = config["dtensor_cfg"].get("expert_parallel_size", 1) + pp_size = config["dtensor_cfg"].get("pipeline_parallel_size", 1) sequence_parallel_enabled = config["dtensor_cfg"]["sequence_parallel"] # Build tp_plan from custom_parallel_plan config if set, else None (auto-select) @@ -459,11 +456,11 @@ def setup_distributed( "If you need this feature, please file an issue on https://github.com/NVIDIA-NeMo/Automodel." 
) - # Create device meshes (dp_size is derived from world_size / (tp * cp * ep)) + # Create device meshes (dp_size is derived from world_size / (pp * tp * cp * ep)) device_mesh, moe_mesh = create_device_mesh( fsdp2_config, tp_size=tp_size, - pp_size=1, + pp_size=pp_size, cp_size=cp_size, ep_size=ep_size, world_size=world_size, @@ -473,6 +470,8 @@ def setup_distributed( resolved_dp_size = device_mesh["dp"].size() resolved_tp_size = device_mesh["tp"].size() resolved_cp_size = device_mesh["cp"].size() + resolved_pp_size = device_mesh["pp"].size() if pp_size > 1 else 1 + pp_mesh = device_mesh["pp"] if pp_size > 1 else None return DistributedContext( device_mesh=device_mesh, @@ -482,9 +481,67 @@ def setup_distributed( dp_size=resolved_dp_size, tp_size=resolved_tp_size, cp_size=resolved_cp_size, + pp_size=resolved_pp_size, + pp_mesh=pp_mesh, ) +def build_pipeline_config( + automodel_kwargs: dict, + config: "PolicyConfig", + model_config: Any, +) -> Any: + """Build and configure PipelineConfig from automodel_kwargs. + + Handles Hydra _target_ resolution, pp_batch_size computation, and + custom model detection for forward patching. + + Returns: + Resolved PipelineConfig, or None if no pipeline_config in kwargs. + """ + if automodel_kwargs.get("pipeline_config") is None: + return None + + pipeline_class = _resolve_target(automodel_kwargs["pipeline_config"]["_target_"]) + pipeline_kwargs = { + k: v for k, v in automodel_kwargs["pipeline_config"].items() if k != "_target_" + } + + # pp_batch_size = what the schedule processes per step() call. + # Equals train_micro_batch_size (fits in GPU memory). Gradient accumulation + # handles the rest (local_gbs / pp_batch_size steps per optimizer update). + if "pp_batch_size" not in pipeline_kwargs or pipeline_kwargs["pp_batch_size"] <= 1: + pipeline_kwargs["pp_batch_size"] = config["train_micro_batch_size"] + + # Custom nemo_automodel models (GPT-OSS, Qwen3 MoE, etc.) have their own + # forward() that handles rotary embeddings, attention masks, and layer + # iteration. Disable the generic HF pipeline_forward patching so the + # custom model's forward is preserved during PP. + is_custom_model = ( + model_config.architectures[0] in ModelRegistry.model_arch_name_to_cls + ) + if is_custom_model: + pipeline_kwargs.setdefault("patch_inner_model", False) + pipeline_kwargs.setdefault("patch_causal_lm_model", False) + + # Provide a dummy loss_fn — the PP schedule requires it at build time, + # but NeMo RL injects the real loss via schedule._loss_fn before each step. 
+ if "loss_fn" not in pipeline_kwargs or pipeline_kwargs["loss_fn"] is None: + + def _dummy_loss(output, target): + logits = getattr(output, "logits", output) + return torch.nn.functional.cross_entropy( + logits.view(-1, logits.size(-1)).float(), + target.view(-1), + ignore_index=-100, + reduction="sum", + ) + + pipeline_kwargs["loss_fn"] = _dummy_loss + + return pipeline_class(**pipeline_kwargs) + + def setup_model_and_optimizer( config: PolicyConfig, tokenizer: AutoTokenizer, @@ -582,6 +639,11 @@ def setup_model_and_optimizer( backend = backend_class(**backend_kwargs) automodel_kwargs["backend"] = backend + # Resolve pipeline_config if present + pipeline_config = build_pipeline_config(automodel_kwargs, config, model_config) + if pipeline_config is not None: + automodel_kwargs["pipeline_config"] = pipeline_config + if "use_liger_kernel" not in automodel_kwargs: automodel_kwargs["use_liger_kernel"] = False @@ -650,8 +712,13 @@ def setup_model_and_optimizer( print(model) - # Compute model metadata after from_pretrained - model_state_dict_keys = list(model.state_dict().keys()) + # Wrap in ModelHandle for unified PP/non-PP interface + from nemo_rl.models.automodel.model_handle import ModelHandle + + model_handle = ModelHandle(model) + + # Compute model metadata + model_state_dict_keys = model_handle.state_dict_keys() is_moe_model = any(["expert" in key for key in model_state_dict_keys]) is_hf_model = ( model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls @@ -659,70 +726,88 @@ def setup_model_and_optimizer( # Autocast is disabled for custom MoE models (non-HF) to avoid numerical issues autocast_enabled = not (is_moe_model and not is_hf_model) - # Set pad token ID if needed. Some model configs (e.g. Gemma3 in transformers v5) - # don't have pad_token_id as a direct attribute. 
- if getattr(model.config, "pad_token_id", None) is None: - model.config.pad_token_id = tokenizer.pad_token_id - - # Handle tied word embeddings (safety net after from_pretrained) - is_tied_lm_head = hasattr(model, "lm_head") and getattr( - getattr(model, "config", {}), "tie_word_embeddings", False - ) - if is_tied_lm_head: - model.tie_weights() + # Set pad token ID if needed + for part in model_handle.parts: + if ( + hasattr(part, "config") + and getattr(part.config, "pad_token_id", None) is None + ): + part.config.pad_token_id = tokenizer.pad_token_id + + # Handle tied word embeddings (PP validates tie_word_embeddings=False) + if not model_handle.pp_enabled: + is_tied_lm_head = hasattr(model, "lm_head") and getattr( + getattr(model, "config", {}), "tie_word_embeddings", False + ) + if is_tied_lm_head: + model.tie_weights() # CPU offload if needed if cpu_offload: - # Move buffers to CPU for FSDP modules - for v in model.buffers(): - v.data = v.data.to("cpu") - model = model.to("cpu") + model_handle.move_buffers_to("cpu") + model_handle.to("cpu") - # Initialize optimizer - optimizer = None + # Initialize optimizers — always a list (one per model part) + optimizers = None if init_optimizer: optimizer_cls = get_class(config["optimizer"]["name"]) - optimizer = optimizer_cls(model.parameters(), **config["optimizer"]["kwargs"]) + optimizers = [ + optimizer_cls( + [p for p in part.parameters() if p.requires_grad], + **config["optimizer"]["kwargs"], + ) + for part in model_handle.parts + ] + + # Initialize schedulers — one per optimizer + schedulers = None - # Initialize scheduler - scheduler = None - if "scheduler" in config and optimizer is not None: + def _make_scheduler_for_optimizer(opt): + """Create a scheduler for a single optimizer based on config.""" + if "scheduler" not in config: + return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 1) if isinstance(config["scheduler"], dict): - scheduler_cls = get_class(config["scheduler"]["name"]) - scheduler = scheduler_cls(optimizer, **config["scheduler"]["kwargs"]) + sched_cls = get_class(config["scheduler"]["name"]) + return sched_cls(opt, **config["scheduler"]["kwargs"]) else: - schedulers = [] - for scheduler_cfg in config["scheduler"]: - if "name" in scheduler_cfg: - schedulers.append( - get_class(scheduler_cfg["name"])( - optimizer, **scheduler_cfg["kwargs"] - ) + scheds = [] + for sched_cfg in config["scheduler"]: + if "name" in sched_cfg: + scheds.append( + get_class(sched_cfg["name"])(opt, **sched_cfg["kwargs"]) ) else: - assert "milestones" in scheduler_cfg, ( + assert "milestones" in sched_cfg, ( "unknown scheduler config: ", - scheduler_cfg, + sched_cfg, ) - milestones: list[int] = scheduler_cfg["milestones"] + milestones_val: list[int] = sched_cfg["milestones"] + return torch.optim.lr_scheduler.SequentialLR(opt, scheds, milestones_val) - scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, schedulers, milestones - ) - elif optimizer is not None: - # Default to passthrough LR schedule - scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, lr_lambda=lambda epoch: 1 - ) + if optimizers is not None: + schedulers = [_make_scheduler_for_optimizer(opt) for opt in optimizers] - # Load NeMo RL checkpoint if provided + # Load NeMo RL checkpoint if provided. + # For PP, pass model parts list and full optimizer/scheduler lists + # so the automodel Checkpointer can load per-stage state. 
if weights_path: + model_for_ckpt = model_handle.parts if model_handle.pp_enabled else model + opt_for_ckpt = ( + optimizers + if model_handle.pp_enabled + else (optimizers[0] if optimizers else None) + ) + sched_for_ckpt = ( + schedulers + if model_handle.pp_enabled + else (schedulers[0] if schedulers else None) + ) checkpoint_manager.load_checkpoint( - model=model, + model=model_for_ckpt, weights_path=weights_path, - optimizer=optimizer, + optimizer=opt_for_ckpt, optimizer_path=optimizer_path, - scheduler=scheduler, + scheduler=sched_for_ckpt, ) else: print( @@ -730,14 +815,14 @@ def setup_model_and_optimizer( ) return ModelAndOptimizerState( - model=model, - optimizer=optimizer, - scheduler=scheduler, + model=model_handle, + optimizers=optimizers, + schedulers=schedulers, is_hf_model=is_hf_model, is_moe_model=is_moe_model, is_reward_model=is_reward_model, model_class=type(model), - model_config=model.config, + model_config=model_handle.config, peft_config=peft_config, autocast_enabled=autocast_enabled, ) diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index d2b5979400..039cf9f340 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -49,6 +49,7 @@ allgather_cp_sharded_tensor, distributed_vocab_topk, get_logprobs_from_vocab_parallel_logits, + get_next_token_logprobs_from_logits, ) from nemo_rl.models.automodel.data import ProcessedInputs, ProcessedMicrobatch from nemo_rl.models.policy import PolicyConfig @@ -62,6 +63,22 @@ ] + +def _is_custom_automodel(model: nn.Module) -> bool: + """Check if model is a custom nemo_automodel model (not standard HF). + + Custom models (gpt-oss, qwen3-moe, etc.) use **attn_kwargs and expect + individual keys like cu_seqlens, not a bundled flash_attn_kwargs object. + """ + try: + from nemo_automodel._transformers.registry import ModelRegistry + + arch = getattr(getattr(model, "config", None), "architectures", [None])[0] + return arch is not None and arch in ModelRegistry.model_arch_name_to_cls + except (ImportError, IndexError, AttributeError): + return False + + def model_forward( model: nn.Module, processed_inputs: ProcessedInputs, @@ -79,6 +96,14 @@ def model_forward( Returns: torch.Tensor: Output tensor from the model (logits) """ + if processed_inputs.thd_batch is not None: + thd = processed_inputs.thd_batch + model_args = thd.to_model_kwargs(device=processed_inputs.input_ids.device) + model_args.pop("labels", None) # labels are for loss, not model forward + model_args["use_cache"] = False + outputs = model(**model_args) + return outputs + model_args = dict( input_ids=processed_inputs.input_ids, attention_mask=processed_inputs.attention_mask, @@ -86,9 +111,22 @@ def model_forward( use_cache=False, ) - # Add flash attention kwargs if applicable + # Custom automodel models use **attn_kwargs and need cu_seqlens + qkv_format + # directly. HF models expect a bundled flash_attn_kwargs object. if processed_inputs.has_flash_attention: - model_args["flash_attn_kwargs"] = processed_inputs.flash_attn_kwargs + fa_kwargs = processed_inputs.flash_attn_kwargs + if _is_custom_automodel(model): + # Flatten batch dim: [1, packed_len] -> [packed_len] so the model + # produces 2D hidden states that TE auto-detects as THD format. 
+ model_args["input_ids"] = model_args["input_ids"].squeeze(0) + if model_args["position_ids"] is not None: + model_args["position_ids"] = model_args["position_ids"].squeeze(0) + model_args["cu_seqlens"] = fa_kwargs.cu_seqlens_q.to( + dtype=torch.int32, device=model_args["input_ids"].device + ) + model_args["qkv_format"] = "thd" + else: + model_args["flash_attn_kwargs"] = fa_kwargs # Add VLM kwargs if applicable if processed_inputs.is_multimodal: @@ -127,15 +165,21 @@ def extract_logits( outputs: Model outputs (can be tensor, DTensor, or object with logits attribute) Returns: - torch.Tensor: Logits tensor + torch.Tensor: Logits tensor with shape [batch, seq, vocab] """ if isinstance(outputs, (torch.Tensor, DTensor)): - # Custom models can output logits directly - return outputs + # Custom models can output logits directly. + # THD format: GPT-OSS returns [T, V], DeepseekV3 returns [1, T, V]. + # Normalize to [1, T, V] for downstream post-processors. + logits = outputs elif not hasattr(outputs, "logits"): - return model.lm_head(outputs.last_hidden_state) + logits = model.lm_head(outputs.last_hidden_state) else: - return outputs.logits + logits = outputs.logits + + if logits.ndim == 2: + logits = logits.unsqueeze(0) + return logits def apply_temperature_scaling( @@ -534,21 +578,40 @@ def __call__( global_valid_toks: torch.Tensor, sequence_dim: int = 1, ) -> tuple[torch.Tensor, dict[str, Any]]: - """Compute loss from logits. + """Compute loss from logits.""" + if processed_inputs.thd_batch is not None: + thd = processed_inputs.thd_batch + cu_seqlens_actual = thd.cu_seqlens_per_row[0] + # CP uses padded cu_seqlens for logit slicing (positions // cp_size). + # TE's THD partitioning matches Megatron's dual-chunk-swap. + cu_seqlens_padded = ( + thd.cu_seqlens_padded_per_row[0].to(dtype=torch.int32, device=logits.device) + if self.cp_size > 1 else cu_seqlens_actual + ) + cp_group = self.cp_mesh.get_group() if self.cp_size > 1 else None + prepare_fn = partial( + prepare_loss_input, sampling_params=self.sampling_params + ) + loss_fn = SequencePackingLossWrapper( + loss_fn=self.loss_fn, + prepare_fn=prepare_fn, + cu_seqlens_q=cu_seqlens_actual, + cu_seqlens_q_padded=cu_seqlens_padded, + context_parallel_group=cp_group, + ) + loss, loss_metrics = loss_fn( + logits, data_dict, global_valid_seqs, global_valid_toks, + ) + return loss, loss_metrics - Args: - logits: Model output logits - data_dict: Microbatch data - processed_inputs: Processed inputs - global_valid_seqs: Global valid sequence count - global_valid_toks: Global valid token count - sequence_dim: Sequence dimension + # Determine cu_seqlens source for seq packing (FA2 path) + if self.enable_seq_packing and processed_inputs.has_flash_attention: + cu_seqlens = processed_inputs.flash_attn_kwargs.cu_seqlens_q + else: + cu_seqlens = None - Returns: - Tuple of (loss, metrics) - """ - # Handle CP redistribution - if self.cp_size > 1: + # Handle CP redistribution (non-seqpack CP path) + if self.cp_size > 1 and cu_seqlens is None: _, data_dict = prepare_data_for_cp( data_dict, processed_inputs, self.cp_mesh, sequence_dim ) @@ -556,17 +619,17 @@ def __call__( logits, self.device_mesh, self.cp_mesh, sequence_dim ) - # Wrap prepare_loss_input with sampling_params prepare_loss_input_wrapped = partial( prepare_loss_input, sampling_params=self.sampling_params ) - # Wrap loss function for sequence packing if needed - if self.enable_seq_packing: + + if cu_seqlens is not None: + # FA2 seq packing loss path. 
loss_fn = SequencePackingLossWrapper( loss_fn=self.loss_fn, prepare_fn=prepare_loss_input_wrapped, - cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, - cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens, ) loss, loss_metrics = loss_fn( logits, @@ -646,7 +709,58 @@ def __call__( seq_len = processed_inputs.seq_len input_lengths = data_dict["input_lengths"] + if processed_inputs.thd_batch is not None: + thd = processed_inputs.thd_batch + cu_actual = thd.cu_seqlens_per_row[0] + cu_padded = ( + thd.cu_seqlens_padded_per_row[0] + if self.cp_size > 1 + else cu_actual + ) + cp_group = self.cp_mesh.get_group() if self.cp_size > 1 else None + cp_size_val = self.cp_size + + if logits.ndim == 2: + logits = logits.unsqueeze(0) + + # Only iterate over real sequences (not dummy-padded ones from PP) + n_seqs = min(len(cu_actual) - 1, original_batch_size) + unpacked_logprobs = torch.zeros( + (original_batch_size, original_seq_len), + dtype=logits.dtype, device=logits.device, + ) + for seq_idx in range(n_seqs): + actual_len = (cu_actual[seq_idx + 1] - cu_actual[seq_idx]).item() + padded_start = cu_padded[seq_idx].item() + padded_len = (cu_padded[seq_idx + 1] - cu_padded[seq_idx]).item() + + logit_start = padded_start // cp_size_val + logit_length = padded_len // cp_size_val + logit_slice = logits.narrow(1, logit_start, logit_length) + + seq_input_ids = data_dict["input_ids"][seq_idx : seq_idx + 1, :actual_len] + seq_logprobs = get_next_token_logprobs_from_logits( + input_ids=seq_input_ids, + next_token_logits=logit_slice, + context_parallel_group=cp_group, + sampling_params=self.sampling_params, + ) + unpacked_logprobs[seq_idx, 1 : 1 + seq_logprobs.shape[1]] = seq_logprobs[0] + + # Apply post-attention mask + for i, length in enumerate(input_lengths): + unpacked_logprobs[i, int(length):] = 0 + + if need_top_k_or_top_p_filtering(self.sampling_params): + mask = data_dict["token_mask"] * data_dict["sample_mask"].unsqueeze(-1) + unpacked_logprobs = mask_out_neg_inf_logprobs( + unpacked_logprobs, mask, "prev_logprobs" + ) + + return unpacked_logprobs + if self.cp_size > 1: + # Standard DTensor CP path (non-seqpack) seq_index_tensor = ( DTensor.from_local( processed_inputs.seq_index, @@ -672,7 +786,7 @@ def __call__( input_ids_dtensor, seq_index_tensor, chunk_size=self.logprob_chunk_size, - sampling_params=self.sampling_params, # top-k and top-p filtering + sampling_params=self.sampling_params, ) assert token_logprobs.shape[1] == seq_len - 1 @@ -696,14 +810,18 @@ def __call__( [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 ) - # Handle sequence packing unpacking or mask application - if self.enable_seq_packing: + # Handle sequence packing unpacking or mask application (FA2 path). + if self.enable_seq_packing and processed_inputs.has_flash_attention: + cu_seqlens = processed_inputs.flash_attn_kwargs.cu_seqlens_q + else: + cu_seqlens = None + + if cu_seqlens is not None: unpacked_logprobs = torch.zeros( (original_batch_size, original_seq_len), dtype=token_logprobs.dtype, device=token_logprobs.device, ) - cu_seqlens = processed_inputs.flash_attn_kwargs.cu_seqlens_q for i in range(original_batch_size): start = cu_seqlens[i].item() + 1 end = cu_seqlens[i + 1].item() @@ -895,8 +1013,13 @@ def __call__( full_logits = logits.to(torch.float32) vals, idx = torch.topk(full_logits, k=self.k, dim=-1) - # Handle sequence packing unpacking - if self.enable_seq_packing: + # Handle sequence packing unpacking (FA2 path). 
+ if self.enable_seq_packing and processed_inputs.has_flash_attention: + cu_seqlens = processed_inputs.flash_attn_kwargs.cu_seqlens_q + else: + cu_seqlens = None + + if cu_seqlens is not None: # Unpack top-k results from packed format back to original batch format # vals: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] # idx: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] @@ -911,8 +1034,6 @@ def __call__( device=idx.device, ) - cu_seqlens = processed_inputs.flash_attn_kwargs.cu_seqlens_q - for i in range(original_batch_size): start = cu_seqlens[i].item() end = cu_seqlens[i + 1].item() diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index ec4c9e66bb..69e344d3f8 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -94,10 +94,11 @@ class DTensorConfig(TypedDict): env_vars: NotRequired[dict[str, str] | None] _v2: NotRequired[bool] # Distributed parallelism sizes - # data_parallel_size is derived from world_size / (tp * cp * ep) + # data_parallel_size is derived from world_size / (pp * tp * cp * ep) tensor_parallel_size: int context_parallel_size: int expert_parallel_size: NotRequired[int] + pipeline_parallel_size: NotRequired[int] # Distributed config options (mirrors Automodel's FSDP2Config) sequence_parallel: bool activation_checkpointing: bool diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 8acd808b11..fda3b51afe 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -143,6 +143,7 @@ def __init__( tp_size = config["dtensor_cfg"]["tensor_parallel_size"] cp_size = config["dtensor_cfg"]["context_parallel_size"] + pp_size = config["dtensor_cfg"].get("pipeline_parallel_size", 1) env_vars = config["dtensor_cfg"].get("env_vars", {}) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2fa8a8e604..fb04a2ff8b 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -42,8 +42,15 @@ from nemo_rl.models.automodel.data import ( check_sequence_dim, get_microbatch_iterator, + install_thd_squeeze_hook, process_global_batch, ) +from nemo_rl.models.automodel.pipeline_parallel import ( + PPLossAdapter, + broadcast_tensors_from_last_pp_stage, + pp_forward_loop, + pp_train_forward_backward_loop, +) from nemo_rl.models.automodel.setup import ( setup_distributed, setup_model_and_optimizer, @@ -171,8 +178,8 @@ def get_train_context( """Create combined context manager for training with context parallel and autocast.""" with contextlib.ExitStack() as stack: context_parallel_ctx = None - if cp_size > 1: - # Create context parallel context + if cp_size > 1 and cp_buffers: + # Create DTensor CP context. Skip when cp_buffers is empty. 
context_parallel_ctx = create_context_parallel_ctx( cp_mesh=cp_mesh, cp_buffers=cp_buffers, @@ -262,6 +269,9 @@ def __init__( self.dp_size = distributed_context.dp_size self.tp_size = distributed_context.tp_size self.cp_size = distributed_context.cp_size + self.pp_size = distributed_context.pp_size + self.pp_mesh = distributed_context.pp_mesh + self.pp_enabled = self.pp_size > 1 # Initialize checkpoint manager now that distributed is set up self._init_checkpoint_manager( @@ -288,24 +298,31 @@ def __init__( optimizer_path=optimizer_path, ) - # Set instance attributes from model and optimizer state (tuple unpacking) + # Set instance attributes from model and optimizer state ( - self.model, - self.optimizer, - self.scheduler, + self.model_handle, + self.optimizers, + self.schedulers, self.is_hf_model, self.is_moe_model, - self._is_reward_model, # Note: using underscore prefix for internal naming + self._is_reward_model, self.model_class, self.model_config, self.peft_config, self.autocast_enabled, ) = model_and_optimizer_state + # Convenience aliases + self.model = self.model_handle.raw # raw AutoPipeline or nn.Module + self.has_first_stage = self.model_handle.has_first_stage + self.has_last_stage = self.model_handle.has_last_stage + # Initialize reference model if requested self.reference_model_state_dict = None if init_reference_model: - self.reference_model_state_dict = setup_reference_model_state(self.model) + self.reference_model_state_dict = setup_reference_model_state( + self.model_handle + ) # Set instance attributes from runtime config (tuple unpacking) ( @@ -352,11 +369,25 @@ def train( if eval_mode: ctx: AbstractContextManager[Any] = torch.no_grad() - self.model.eval() + self.model_handle.eval() else: ctx = nullcontext() - # Ensure model is in training mode - self.model.train() + self.model_handle.train() + + if self.pp_enabled: + return self._train_pp( + data=data, + loss_fn=loss_fn, + eval_mode=eval_mode, + local_gbs=local_gbs, + num_global_batches=num_global_batches, + sequence_dim=sequence_dim, + ctx=ctx, + optimizers_list=self.optimizers, + schedulers_list=self.schedulers, + ) + + # --- Non-PP path (existing) --- # Create loss post-processor loss_post_processor = LossPostProcessor( @@ -414,7 +445,8 @@ def on_microbatch_start(mb_idx): global_valid_seqs = gb_result["global_valid_seqs"] global_valid_toks = gb_result["global_valid_toks"] - self.optimizer.zero_grad() + for opt in self.optimizers: + opt.zero_grad() # Get microbatch iterator based on batching strategy processed_iterator, iterator_len = get_microbatch_iterator( @@ -424,6 +456,8 @@ def on_microbatch_start(mb_idx): self.dp_mesh, tokenizer=self.tokenizer, cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + device_mesh=self.device_mesh, ) # Use automodel_forward_backward for the training loop @@ -452,7 +486,7 @@ def on_microbatch_start(mb_idx): # Only process valid (non-dummy) batches for metrics if mb_idx < iterator_len: num_valid_samples = loss_metrics["num_valid_samples"] - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["lr"] = self.optimizers[0].param_groups[0]["lr"] loss_metrics["global_valid_seqs"] = global_valid_seqs.item() loss_metrics["global_valid_toks"] = global_valid_toks.item() @@ -483,15 +517,18 @@ def on_microbatch_start(mb_idx): ) # Update parameters - self.optimizer.step() + for opt in self.optimizers: + opt.step() losses.append(torch.tensor(mb_losses).sum().item()) # release gradient memory before rollouts - self.optimizer.zero_grad() + for opt in self.optimizers: + opt.zero_grad() # 
increment scheduler after all batches in rollout are processed if not eval_mode: - self.scheduler.step() + for sched in self.schedulers: + sched.step() # dynamic batch and sequence dims causes alot of fragmentation, so clear # the memory allocator before moving on torch.cuda.empty_cache() @@ -507,6 +544,159 @@ def on_microbatch_start(mb_idx): return metrics + def _train_pp( + self, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + eval_mode: bool, + local_gbs: int, + num_global_batches: int, + sequence_dim: int, + ctx: AbstractContextManager[Any], + optimizers_list: list, + schedulers_list: list, + ) -> dict[str, Any]: + """PP training path using pp_train_forward_backward_loop.""" + loss_adapter = PPLossAdapter( + loss_fn=loss_fn, + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + dp_size=self.dp_size, + enable_seq_packing=self.enable_seq_packing, + sampling_params=self.sampling_params, + ) + + pp_batch_size = self.model.pp_batch_size + n_microbatches = self.model.info.schedule._n_microbatches + + # Install hook to squeeze batch dim for THD format in PP microbatches. + # The PP schedule splits [n_chunks, tokens_per_chunk] → [1, tokens_per_chunk] + # per microbatch, but custom models need [tokens_per_chunk] (1D) for THD. + thd_hooks: list = [] + if self.enable_seq_packing and not self.is_hf_model: + thd_hooks = install_thd_squeeze_hook(self.model_handle.parts) + + with ctx: + data = data.to("cuda") + losses = [] + all_mb_metrics = [] + grad_norm: Optional[float | torch.Tensor] = None + + for gb_idx in range(num_global_batches): + gb_result = process_global_batch( + data, + loss_fn, + self.dp_mesh.get_group(), + batch_idx=gb_idx, + batch_size=local_gbs, + ) + global_valid_seqs = gb_result["global_valid_seqs"] + global_valid_toks = gb_result["global_valid_toks"] + + for opt in optimizers_list: + opt.zero_grad() + + gb_total_loss, gb_metrics = pp_train_forward_backward_loop( + model=self.model, + model_parts=self.model_handle.parts, + batch=gb_result["batch"], + loss_adapter=loss_adapter, + pp_batch_size=pp_batch_size, + n_microbatches=n_microbatches, + has_last_stage=self.has_last_stage, + pp_mesh=self.pp_mesh, + enable_seq_packing=self.enable_seq_packing, + is_hf_model=self.is_hf_model, + tokenizer_eos_id=self.tokenizer.eos_token_id, + train_mb_tokens=self.cfg["sequence_packing"]["train_mb_tokens"], + cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + device_mesh=self.device_mesh, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + num_global_batches=num_global_batches, + forward_only=eval_mode, + ) + + # Enrich metrics with lr and global counts + if gb_metrics: + for m in gb_metrics: + lr = optimizers_list[0].param_groups[0]["lr"] + m["lr"] = lr + m["global_valid_seqs"] = global_valid_seqs.item() + m["global_valid_toks"] = global_valid_toks.item() + if m.get("num_valid_samples", 0) > 0: + all_mb_metrics.append(m) + + if not eval_mode: + # PP gradient scaling: each schedule.step() scales loss by + # dp_size*cp_size (cancelling FSDP averaging), matching the + # non-PP path. Set num_label_tokens = dp_group_size so the + # PP scaling in scale_grads_and_clip_grad_norm is a no-op + # (divides by 1). Gradient accumulation naturally sums across + # steps — this matches non-PP behavior where the full batch + # gradient is computed in one pass. 
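The scaling argument in the comment above can be sanity-checked with a few lines of arithmetic. This is only a sketch with toy values; the exact normalization applied inside `scale_grads_and_clip_grad_norm` is assumed here to reduce to a factor of `dp_group_size / num_label_tokens`.

# Toy check of the PP gradient-scaling argument above.
dp_size, cp_size = 4, 2
loss_scale = dp_size * cp_size                 # applied by each schedule.step()
fsdp_mean = 1.0 / (dp_size * cp_size)          # gradient all-reduce mean over the dp*cp group
assert loss_scale * fsdp_mean == 1.0           # net gradient is unscaled, as in the non-PP path

dp_group_size = dp_size * cp_size
num_label_tokens = dp_group_size               # the value passed below
assert dp_group_size / num_label_tokens == 1.0 # the extra PP rescaling becomes a no-op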
+ dp_group_size = self.dp_size * self.cp_size + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + self.model_handle.parts, + norm_type=2.0, + pp_enabled=True, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name="pp", + foreach=True, + num_label_tokens=dp_group_size, + dp_group_size=dp_group_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) + for opt in optimizers_list: + opt.step() + + # Broadcast loss from last PP stage so all ranks report correct value. + # Undo dp*cp scaling from PPLossAdapter (needed for FSDP grad averaging + # during training, but not for the reported loss). + if self.pp_size > 1: + loss_tensor = gb_total_loss.unsqueeze(0) + broadcasted = broadcast_tensors_from_last_pp_stage( + {"loss": loss_tensor if self.has_last_stage else None}, + self.pp_mesh, + self.has_last_stage, + ) + gb_total_loss = broadcasted["loss"].squeeze(0) / ( + self.dp_size * self.cp_size + ) + losses.append(gb_total_loss.item()) + + for opt in optimizers_list: + opt.zero_grad() + if not eval_mode: + for sched in schedulers_list: + sched.step() + torch.cuda.empty_cache() + + # Remove THD squeeze hooks + for h in thd_hooks: + h.remove() + + metrics = aggregate_training_statistics( + losses=losses, + all_mb_metrics=all_mb_metrics, + grad_norm=grad_norm, + dp_group=self.dp_mesh.get_group(), + dtype=self.dtype, + ) + return metrics + @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_logprobs") def get_logprobs( self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None @@ -529,8 +719,14 @@ def get_logprobs( else self.cfg["logprob_batch_size"] ) - # Validate sequence dimension + # Validate sequence dimension. Sync across ranks for PP (different ranks + # may have different max sequence lengths from GRPO generation). 
sequence_dim, seq_dim_size = check_sequence_dim(data) + if self.pp_enabled: + max_seq = torch.tensor([seq_dim_size], device="cuda") + torch.distributed.all_reduce(max_seq, op=torch.distributed.ReduceOp.MAX) + seq_dim_size = int(max_seq.item()) + return self._get_logprobs_pp(data, logprob_batch_size, seq_dim_size) all_log_probs = [] self.model.eval() @@ -556,6 +752,8 @@ def get_logprobs( self.dp_mesh, tokenizer=self.tokenizer, cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + device_mesh=self.device_mesh, ) for batch_idx, processed_mb in enumerate(processed_iterator): @@ -601,6 +799,42 @@ def get_logprobs( return return_data + def _get_logprobs_pp( + self, + data: BatchedDataDict[Any], + logprob_batch_size: int, + seq_dim_size: int, + ) -> BatchedDataDict[LogprobOutputSpec]: + """PP path for logprob computation using pp_forward_loop.""" + for mp in self.model_handle.parts: + mp.eval() + + post_processor = LogprobsPostProcessor( + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + enable_seq_packing=self.enable_seq_packing, + sampling_params=self.sampling_params, + ) + + with torch.no_grad(): + result = pp_forward_loop( + model=self.model, + data=data, + post_processor=post_processor, + pp_batch_size=self.model.pp_batch_size, + seq_dim_size=seq_dim_size, + pp_mesh=self.pp_mesh, + has_first_stage=self.has_first_stage, + has_last_stage=self.has_last_stage, + ) + + return_data = BatchedDataDict[LogprobOutputSpec]() + return_data["logprobs"] = result["logprobs"].cpu() + return return_data + @wrap_with_nvtx_name("dtensor_policy_worker_v2/score") def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: global_batch_size = min(self.cfg["batch_size"], data.size) @@ -624,6 +858,8 @@ def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: self.dp_mesh, tokenizer=self.tokenizer, cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + device_mesh=self.device_mesh, ) all_rm_scores = [] @@ -690,6 +926,9 @@ def get_topk_logits( # Validate sequence dimension sequence_dim, seq_dim_size = check_sequence_dim(data) + if self.pp_enabled: + return self._get_topk_logits_pp(data, k, topk_batch_size, seq_dim_size) + out_topk_vals = [] out_topk_idx = [] self.model.eval() @@ -715,6 +954,8 @@ def get_topk_logits( self.dp_mesh, tokenizer=self.tokenizer, cp_size=self.cp_size, + cp_mesh=self.cp_mesh, + device_mesh=self.device_mesh, ) for batch_idx, processed_mb in enumerate(processed_iterator): @@ -778,6 +1019,144 @@ def get_topk_logits( ).cpu() return ret + def _get_topk_logits_pp( + self, + data: BatchedDataDict[Any], + k: int, + topk_batch_size: int, + seq_dim_size: int, + ) -> BatchedDataDict[Any]: + """PP path for top-k logits using pp_forward_loop.""" + for mp in self.model_handle.parts: + mp.eval() + + post_processor = TopkLogitsPostProcessor( + cfg=self.cfg, + device_mesh=self.device_mesh, + cp_mesh=self.cp_mesh, + tp_mesh=self.tp_mesh, + cp_size=self.cp_size, + k=k, + enable_seq_packing=self.enable_seq_packing, + ) + + with torch.no_grad(): + result = pp_forward_loop( + model=self.model, + data=data, + post_processor=post_processor, + pp_batch_size=self.model.pp_batch_size, + seq_dim_size=seq_dim_size, + pp_mesh=self.pp_mesh, + has_first_stage=self.has_first_stage, + has_last_stage=self.has_last_stage, + ) + + ret = BatchedDataDict[Any]() + ret["topk_logits"] = result["topk_logits"].cpu() + ret["topk_indices"] = result["topk_indices"].cpu() + return ret + + def _all_params_generator(self): + """Yield (name, tensor) 
pairs for ALL model params, gathering across PP ranks. + + For non-PP: yields from the single model's state dict. + For PP: each rank has a subset of params. All PP ranks cooperate to + broadcast each param one at a time (like megatron bridge's + stream_weights_megatron_to_hf) so every rank can yield the full model + without holding all remote params in memory. + """ + if not self.pp_enabled: + yield from dtensor_params_generator(self.model, self.dtype) + return + + # Phase 1: build a lazy local param lookup (name → part index). + # We don't materialise all tensors yet to save memory. + local_names: set[str] = set() + for part in self.model_handle.parts: + for name in part.state_dict().keys(): + if not ( + name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight") + ): + local_names.add(name) + + # Phase 2: all-gather the name lists to build a global ordering. + pp_group = self.pp_mesh.get_group() + pp_ranks = torch.distributed.get_process_group_ranks(pp_group) + all_names_list: list[list[str]] = [None] * self.pp_size # type: ignore[list-item] + torch.distributed.all_gather_object( + all_names_list, sorted(local_names), group=pp_group + ) + + name_to_owner: dict[str, int] = {} + for pp_r, names in enumerate(all_names_list): + for n in names: + name_to_owner[n] = pp_r + + my_pp_rank = self.pp_mesh.get_local_rank() + + # Phase 3: iterate all params in sorted order. Owner broadcasts one + # tensor at a time; others allocate a receive buffer and receive. + for name in sorted(name_to_owner.keys()): + owner_pp_rank = name_to_owner[name] + owner_global_rank = pp_ranks[owner_pp_rank] + + if my_pp_rank == owner_pp_rank: + # Materialise and adapt the tensor (TP gather + HF name adapt) + tensor, owning_part = self._get_local_param_tensor(name) + tensor = tensor.cuda() + meta = [list(tensor.shape), str(tensor.dtype)] + torch.distributed.broadcast_object_list( + [meta], src=owner_global_rank, group=pp_group + ) + torch.distributed.broadcast( + tensor, src=owner_global_rank, group=pp_group + ) + for adapted_name, adapted_tensor in _maybe_adapt_tensor_to_hf( + owning_part, name, tensor + ): + yield ( + adapted_name, + adapted_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + else: + meta_list = [None] + torch.distributed.broadcast_object_list( + meta_list, src=owner_global_rank, group=pp_group + ) + shape, dtype_str = meta_list[0] + dtype = getattr(torch, dtype_str.replace("torch.", "")) + buf = torch.empty(shape, dtype=dtype, device="cuda") + torch.distributed.broadcast(buf, src=owner_global_rank, group=pp_group) + for adapted_name, adapted_tensor in _maybe_adapt_tensor_to_hf( + self.model_handle.parts[0], name, buf + ): + yield ( + adapted_name, + adapted_tensor.to(self.dtype, non_blocking=True).contiguous(), + ) + + def _get_local_param_tensor(self, name: str) -> tuple[torch.Tensor, nn.Module]: + """Get a single param tensor by name from local model parts, gathering DTensor. + + Returns: + Tuple of (tensor, owning_part) so callers can pass the correct + model part to state_dict adapters. 
+ """ + for part in self.model_handle.parts: + sd = part.state_dict() + if name in sd: + tensor = sd[name] + gathered = ( + tensor.full_tensor() if isinstance(tensor, DTensor) else tensor + ) + return gathered, part + raise KeyError(f"Param {name} not found in any local model part") + + def _model_state_dict_items(self): + """Iterate state_dict items (unified for PP and non-PP via ModelHandle).""" + return self.model_handle.state_dict_items() + @contextmanager def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. @@ -790,19 +1169,15 @@ def use_reference_model(self) -> Generator[None, None, None]: with torch.no_grad(): # Save train model state_dict curr_state_dict = get_cpu_state_dict( - self.model.state_dict().items(), pin_memory=True + self._model_state_dict_items(), pin_memory=True ) # Swap reference model state_dict to self.model - for k, v in self.model.state_dict().items(): + for k, v in self._model_state_dict_items(): val = to_local_if_dtensor(v) val.copy_(self.reference_model_state_dict[k]) # Temporarily disable top-k/top-p filtering for reference policy logprobs. - # The reference policy has different weights, so its top-k/top-p set is - # inherently different from the current policy. Using filtered logprobs - # would cause -inf mismatches that cannot be resolved by masking. - # Note: We keep temperature scaling since it was applied to prev_logprobs. saved_sampling_params = self.sampling_params if saved_sampling_params is not None: self.sampling_params = TrainingSamplingParams( @@ -813,15 +1188,13 @@ def use_reference_model(self) -> Generator[None, None, None]: else: self.sampling_params = None - # - self.model is the original reference_model, now on CUDA - # - curr_state_dict is the train model, now on CPU yield # Restore sampling_params self.sampling_params = saved_sampling_params # Restore train model state_dict - for k, v in self.model.state_dict().items(): + for k, v in self._model_state_dict_items(): val = to_local_if_dtensor(v) val.copy_(curr_state_dict[k]) @@ -835,34 +1208,42 @@ def _add_noise_to_weights(self) -> None: torch.cuda.synchronize() def return_state_dict(self): - return self.model.state_dict() + return self.model_handle.state_dict() def return_model_config(self) -> dict[str, Any]: - """Return the model configuration as a dictionary. - - Returns: - dict: Model configuration dictionary - """ - return self.model.config + """Return the model configuration as a dictionary.""" + return self.model_handle.config @torch.no_grad() def prepare_refit_info(self) -> Optional[dict[str, Any]]: - """Prepare state dict metadata for weight refitting and IPC streaming.""" - state_dict_info = {} - for name, tensor in self.model.state_dict().items(): - if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): - continue - full_tensor = ( - tensor.full_tensor() if isinstance(tensor, DTensor) else tensor - ) - # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective - adapted_fqn_tensors = _maybe_adapt_tensor_to_hf( - self.model, name, full_tensor - ) - for adapted_fqn, adapted_tensor in adapted_fqn_tensors: - state_dict_info[adapted_fqn] = (adapted_tensor.shape, self.dtype) + """Prepare state dict metadata for weight refitting and IPC streaming. - return state_dict_info + For PP: each rank only has a subset of params. All-gather param info + across PP ranks so every rank returns the full model's metadata. 
+ """ + local_info = {} + for part in self.model_handle.parts: + for name, tensor in part.state_dict().items(): + if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): + continue + full_tensor = ( + tensor.full_tensor() if isinstance(tensor, DTensor) else tensor + ) + adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(part, name, full_tensor) + for adapted_fqn, adapted_tensor in adapted_fqn_tensors: + local_info[adapted_fqn] = (adapted_tensor.shape, self.dtype) + + if self.pp_enabled: + # All-gather param info across PP ranks so all workers return full metadata + pp_group = self.pp_mesh.get_group() + all_infos = [None] * self.pp_size + torch.distributed.all_gather_object(all_infos, local_info, group=pp_group) + state_dict_info = {} + for info in all_infos: + state_dict_info.update(info) + return state_dict_info + + return local_info @torch.no_grad() def calibrate_qkv_fp8_scales( @@ -898,9 +1279,8 @@ def stream_weights_via_ipc_zmq( from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl - # Use the shared implementation stream_weights_via_ipc_zmq_impl( - params_generator=dtensor_params_generator(self.model, self.dtype), + params_generator=self._all_params_generator(), buffer_size_bytes=buffer_size_bytes, zmq_socket=self.zmq_socket, rank=self.rank, @@ -976,7 +1356,7 @@ def broadcast_weights_for_collective( dtensor_post_iter_func = lambda x: x[1] packed_broadcast_producer( - iterator=dtensor_params_generator(self.model, self.dtype), + iterator=self._all_params_generator(), group=self.model_update_group, src=0, post_iter_func=dtensor_post_iter_func, @@ -993,13 +1373,12 @@ def prepare_for_lp_inference(self) -> None: if not self.cpu_offload: self.move_to_cuda(self.model) else: - self.model = self.move_buffer_to_device(self.model, "cuda") - - self.model.eval() + self.model_handle.move_buffers_to("cuda") + self.model_handle.eval() # offload optimizer to cpu torch.randn(1).cuda() # wake up torch allocator - if self.optimizer is not None and self.offload_optimizer_for_logprob: + if self.optimizers is not None and self.offload_optimizer_for_logprob: self.move_optimizer_to_device("cpu") gc.collect() @@ -1007,19 +1386,15 @@ def prepare_for_lp_inference(self) -> None: @wrap_with_nvtx_name("dtensor_policy_worker_v2/prepare_for_training") def prepare_for_training(self, *args, **kwargs) -> None: - # onload models and optimizer state to cuda if not self.cpu_offload: self.move_to_cuda(self.model) else: - # when cpu offload is enabled, the buffers do not get moved - # to cuda automatically, so we need to do that manually - self.model = self.move_buffer_to_device(self.model, "cuda") - - self.model.train() + self.model_handle.move_buffers_to("cuda") + self.model_handle.train() # Move optimizer state to CUDA if it exists # colocated generation will always offload optimizer to cuda before refit if ( - self.optimizer is not None + self.optimizers is not None and not self.cpu_offload and (self.offload_optimizer_for_logprob or self.is_generation_colocated) ): @@ -1032,7 +1407,7 @@ def prepare_for_training(self, *args, **kwargs) -> None: def offload_before_refit(self) -> None: """Offload the optimizer to the CPU.""" torch.randn(1).cuda() # wake up torch allocator - if self.optimizer is not None: + if self.optimizers is not None: self.move_optimizer_to_device("cpu") gc.collect() @@ -1043,7 +1418,7 @@ def offload_before_refit(self) -> None: def offload_after_refit(self) -> None: """Offload as much as possible on the CPU.""" self.model = self.move_to_cpu(self.model) - self.model.eval() + 
self.model_handle.eval() torch.randn(1).cuda() # wake up torch allocator self.offload_before_refit() # rerun the old offload function @@ -1055,22 +1430,22 @@ def offload_after_refit(self) -> None: ) def move_optimizer_to_device(self, device: str | torch.device) -> None: - for state in self.optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, (DTensor, torch.Tensor)): - state[k] = v.to(device) + opts = self.optimizers + for opt in opts: + for state in opt.state.values(): + for k, v in state.items(): + if isinstance(v, (DTensor, torch.Tensor)): + state[k] = v.to(device) def move_to_device(self, model: nn.Module, device: str | torch.device) -> nn.Module: - model = self.move_buffer_to_device(model, device) - return model.to(device) + self.model_handle.move_buffers_to(device) + self.model_handle.to(device) + return model def move_buffer_to_device( self, model: nn.Module, device: str | torch.device ) -> nn.Module: - # FSDP modules do not move buffers to the device automatically - for v in model.buffers(): - torch.utils.swap_tensors(v, v.to(device)) - + self.model_handle.move_buffers_to(device) return model def move_to_cuda(self, model: torch.nn.Module) -> torch.nn.Module: @@ -1095,13 +1470,29 @@ def save_checkpoint( """Save a checkpoint of the model. the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + For PP, passes model parts list and optimizer/scheduler lists so the + automodel Checkpointer can save per-stage state. """ + # For PP, pass model parts list; for non-PP, pass single model. + # The automodel Checkpointer's ModelState/OptimizerState wrappers + # handle both single nn.Module and lists transparently. + model_for_ckpt = self.model_handle.parts if self.pp_enabled else self.model + optimizer_for_ckpt = ( + self.optimizers + if self.pp_enabled + else (self.optimizers[0] if self.optimizers else None) + ) + scheduler_for_ckpt = ( + self.schedulers + if self.pp_enabled + else (self.schedulers[0] if self.schedulers else None) + ) self.checkpoint_manager.save_checkpoint( - model=self.model, + model=model_for_ckpt, weights_path=weights_path, - optimizer=self.optimizer, + optimizer=optimizer_for_ckpt, optimizer_path=optimizer_path, - scheduler=self.scheduler, + scheduler=scheduler_for_ckpt, tokenizer=self.tokenizer if tokenizer_path else None, tokenizer_path=tokenizer_path, checkpointing_cfg=checkpointing_cfg, @@ -1115,12 +1506,23 @@ def load_checkpoint( optimizer_path: Optional[str] = None, ) -> None: """Load a checkpoint into the model using Automodel Checkpointer.""" + model_for_ckpt = self.model_handle.parts if self.pp_enabled else self.model + optimizer_for_ckpt = ( + self.optimizers + if self.pp_enabled + else (self.optimizers[0] if self.optimizers else None) + ) + scheduler_for_ckpt = ( + self.schedulers + if self.pp_enabled + else (self.schedulers[0] if self.schedulers else None) + ) self.checkpoint_manager.load_checkpoint( - model=self.model, + model=model_for_ckpt, weights_path=weights_path, - optimizer=self.optimizer, + optimizer=optimizer_for_ckpt, optimizer_path=optimizer_path, - scheduler=self.scheduler, + scheduler=scheduler_for_ckpt, ) def _init_checkpoint_manager( @@ -1142,6 +1544,7 @@ def _init_checkpoint_manager( dp_mesh=self.dp_mesh, tp_mesh=self.tp_mesh, moe_mesh=self.moe_mesh, + pp_mesh=getattr(self, "pp_mesh", None), ) self.checkpoint_manager.init_checkpointer( config_updates=config_updates, diff --git a/tests/test_suites/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.sh b/tests/test_suites/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.sh 
new file mode 100755 index 0000000000..5205ae90a5 --- /dev/null +++ b/tests/test_suites/llm/dpo-llama3.1-8b-tulu3-1n8g-pp2.sh @@ -0,0 +1,46 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=150 +MAX_STEPS=150 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=60 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_dpo.py \ + --config $CONFIG_PATH \ + dpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +# Slightly relaxed thresholds vs non-PP baseline (PP=2 vs PP=1 gives ~1% variance) +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/sft_loss"]["1"] < 0.00001' \ + 'data["train/sft_loss"]["150"] < 0.00001' \ + 'data["train/preference_loss"]["1"] > 0.6930' \ + 'data["train/preference_loss"]["1"] < 0.6932' \ + 'data["train/preference_loss"]["150"] < 0.68' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.sh b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.sh new file mode 100755 index 0000000000..a158473c9f --- /dev/null +++ b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack.sh @@ -0,0 +1,45 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=30 +MAX_STEPS=30 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"]) < 0.001' \ + 'data["train/gen_kl_error"]["30"] < 0.001 ' \ + 'data["train/reward"]["30"] > 0.4' \ + 'data["train/grad_norm"]["30"] < 0.2' \ + 'data["train/grad_norm"]["30"] > 0.1' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.sh b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.sh new file mode 100755 index 0000000000..a158473c9f --- /dev/null +++ b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack.sh @@ -0,0 +1,45 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=30 +MAX_STEPS=30 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"]) < 0.001' \ + 'data["train/gen_kl_error"]["30"] < 0.001 ' \ + 'data["train/reward"]["30"] > 0.4' \ + 'data["train/grad_norm"]["30"] < 0.2' \ + 'data["train/grad_norm"]["30"] > 0.1' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.sh b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.sh new file mode 100755 index 0000000000..014f7ecd9b --- /dev/null +++ b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.sh @@ -0,0 +1,50 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=30 +MAX_STEPS=30 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# PP requires automodel with update_seq_len (PR #1689). 
+WORKER_VENV=/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2 +$WORKER_VENV/bin/python3 -c "from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline; assert hasattr(AutoPipeline, 'update_seq_len')" 2>/dev/null \ + || $WORKER_VENV/bin/pip install -e $PROJECT_ROOT/3rdparty/Automodel-workspace/Automodel --no-deps -q + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Slightly relaxed thresholds for PP +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"]) < 0.001' \ + 'data["train/gen_kl_error"]["30"] < 0.001 ' \ + 'data["train/reward"]["30"] > 0.3' \ + 'data["train/grad_norm"]["30"] < 0.25' \ + 'data["train/grad_norm"]["30"] > 0.08' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.sh b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.sh new file mode 100755 index 0000000000..5cbdfddbb4 --- /dev/null +++ b/tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-pp2ep4.sh @@ -0,0 +1,52 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=30 +MAX_STEPS=30 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# PP requires automodel with update_seq_len (PR #1689). +# Install from submodule into the DTensor v2 worker venv if the container's version is older. 
+WORKER_VENV=/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2 +$WORKER_VENV/bin/python3 -c "from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline; assert hasattr(AutoPipeline, 'update_seq_len')" 2>/dev/null \ + || $WORKER_VENV/bin/pip install -e $PROJECT_ROOT/3rdparty/Automodel-workspace/Automodel --no-deps -q + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +# Slightly relaxed thresholds vs non-PP baseline +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/gen_kl_error"]) < 0.001' \ + 'data["train/gen_kl_error"]["30"] < 0.001 ' \ + 'data["train/reward"]["30"] > 0.3' \ + 'data["train/grad_norm"]["30"] < 0.25' \ + 'data["train/grad_norm"]["30"] > 0.08' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.sh b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.sh new file mode 100755 index 0000000000..2b37048106 --- /dev/null +++ b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 7.0' \ + 'data["train/loss"]["50"] < 0.4' \ + 'data["train/grad_norm"]["50"] < 17.5' \ + 'data["train/grad_norm"]["50"] > 10.0' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.sh b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.sh new file mode 100755 index 0000000000..524866800d --- /dev/null +++ b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-automodel.sh @@ -0,0 +1,51 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=60 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# PP requires automodel with update_seq_len (PR #1689). +# Install from submodule into the DTensor v2 worker venv if the container's version is older. +WORKER_VENV=/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2 +$WORKER_VENV/bin/python3 -c "from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline; assert hasattr(AutoPipeline, 'update_seq_len')" 2>/dev/null \ + || $WORKER_VENV/bin/pip install -e $PROJECT_ROOT/3rdparty/Automodel-workspace/Automodel --no-deps -q + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +# Slightly relaxed thresholds vs non-PP baseline (PP=2 vs PP=1 gives ~1% variance) +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 7.5' \ + 'data["train/loss"]["50"] < 0.5' \ + 'data["train/grad_norm"]["50"] < 20.0' \ + 'data["train/grad_norm"]["50"] > 8.0' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.sh b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.sh new file mode 100755 index 0000000000..5d4c9e0653 --- /dev/null +++ b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.sh @@ -0,0 +1,50 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=60 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# PP requires automodel with update_seq_len (PR #1689). +# Install from submodule into the DTensor v2 worker venv if the container's version is older. 
+WORKER_VENV=/opt/ray_venvs/nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2 +$WORKER_VENV/bin/python3 -c "from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline; assert hasattr(AutoPipeline, 'update_seq_len')" 2>/dev/null \ + || $WORKER_VENV/bin/pip install -e $PROJECT_ROOT/3rdparty/Automodel-workspace/Automodel --no-deps -q + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Slightly relaxed thresholds vs non-PP baseline (PP=2 gives ~1% variance) +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 7.5' \ + 'data["train/loss"]["50"] < 0.5' \ + 'data["train/grad_norm"]["50"] < 20.0' \ + 'data["train/grad_norm"]["50"] > 8.0' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.sh b/tests/test_suites/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.sh new file mode 100755 index 0000000000..1eb338a7a9 --- /dev/null +++ b/tests/test_suites/llm/sft-moonlight-16b-1n8g-cp2ep4-seqpack-automodel.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 4.0' \ + 'data["train/loss"]["50"] < 2.0' \ + 'data["train/grad_norm"]["50"] < 20.0' \ + 'data["train/grad_norm"]["50"] > 3.0' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/unit/models/automodel/test_pipeline_parallel.py b/tests/unit/models/automodel/test_pipeline_parallel.py new file mode 100644 index 0000000000..d019ceb6d5 --- /dev/null +++ b/tests/unit/models/automodel/test_pipeline_parallel.py @@ -0,0 +1,389 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for pipeline_parallel.py utilities.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +try: + import nemo_automodel # noqa: F401 +except ImportError: + pytest.skip("nemo_automodel not available", allow_module_level=True) + +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.models.automodel.pipeline_parallel import ( + PPLogitsCapturer, + PPLossAdapter, + _reset_pp_schedule_state, + pad_batch_for_pp, +) + +# ===================== +# Fixtures +# ===================== + + +@pytest.fixture +def mock_loss_fn(): + loss_fn = MagicMock(spec=LossFunction) + loss_fn.return_value = (torch.tensor(0.5), {"loss": 0.5, "num_valid_samples": 4}) + loss_fn.input_type = LossInputType.LOGIT + return loss_fn + + +@pytest.fixture +def mock_device_mesh(): + mesh = MagicMock() + mesh.get_group.return_value = MagicMock() + mesh.__getitem__ = MagicMock(return_value=mesh) + return mesh + + +@pytest.fixture +def mock_cp_mesh(): + mesh = MagicMock() + mesh.get_group.return_value = MagicMock() + return mesh + + +@pytest.fixture +def mock_tp_mesh(): + mesh = MagicMock() + mesh.get_group.return_value = MagicMock() + return mesh + + +@pytest.fixture +def base_cfg(): + return { + "dtensor_cfg": {"sequence_parallel": False}, + "sequence_packing": {"train_mb_tokens": 256}, + "generation": {"temperature": 1.0, "top_p": 1.0, "top_k": None}, + } + + +# ===================== +# Test PPLogitsCapturer +# ===================== +@pytest.mark.automodel +class TestPPLogitsCapturer: + def test_captures_logits(self): + capturer = PPLogitsCapturer() + mock_output = MagicMock() + mock_output.logits = torch.randn(4, 64, 32000) + target = torch.randint(0, 32000, (4, 64)) + + result = capturer(mock_output, target) + + assert len(capturer.captured_logits) == 1 + assert torch.equal(capturer.captured_logits[0], mock_output.logits) + assert result.item() == 0.0 + + def test_captures_multiple_microbatches(self): + capturer = PPLogitsCapturer() + for i in range(3): + mock_output = MagicMock() + mock_output.logits = torch.randn(2, 32, 1000) + capturer(mock_output, torch.zeros(2, 32, dtype=torch.long)) + + assert len(capturer.captured_logits) == 3 + + def test_reset_clears_state(self): + capturer = PPLogitsCapturer() + mock_output = MagicMock() + mock_output.logits = torch.randn(2, 32, 1000) + capturer(mock_output, torch.zeros(2, 32, dtype=torch.long)) + + capturer.reset() + assert len(capturer.captured_logits) == 0 + + def test_captures_raw_tensor(self): + """When output is a tensor (not an object with .logits), capture directly.""" + capturer = PPLogitsCapturer() + raw_logits = torch.randn(4, 64, 32000) + target = torch.randint(0, 32000, (4, 64)) + + capturer(raw_logits, target) + + assert len(capturer.captured_logits) == 1 + assert torch.equal(capturer.captured_logits[0], raw_logits) + + def test_captured_logits_are_detached(self): + capturer = PPLogitsCapturer() + logits = torch.randn(2, 4, 100, requires_grad=True) 
+ capturer(logits, torch.zeros(2, 4, dtype=torch.long)) + + assert not capturer.captured_logits[0].requires_grad + + +# ===================== +# Test PPLossAdapter +# ===================== +@pytest.mark.automodel +class TestPPLossAdapter: + def test_set_microbatches_splits_data( + self, mock_loss_fn, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh + ): + adapter = PPLossAdapter( + loss_fn=mock_loss_fn, + cfg=base_cfg, + device_mesh=mock_device_mesh, + cp_mesh=mock_cp_mesh, + tp_mesh=mock_tp_mesh, + cp_size=1, + dp_size=1, + ) + + data = { + "input_ids": torch.randint(0, 1000, (8, 64)), + "input_lengths": torch.full((8,), 64), + "sample_mask": torch.ones(8), + } + n_microbatches = 2 + + adapter.set_microbatches( + data, + n_microbatches, + global_valid_seqs=torch.tensor(8), + global_valid_toks=torch.tensor(512), + ) + + assert len(adapter._microbatches) == 2 + assert adapter._microbatches[0]["input_ids"].shape[0] == 4 + assert adapter._microbatches[1]["input_ids"].shape[0] == 4 + + def test_reset_clears_state( + self, mock_loss_fn, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh + ): + adapter = PPLossAdapter( + loss_fn=mock_loss_fn, + cfg=base_cfg, + device_mesh=mock_device_mesh, + cp_mesh=mock_cp_mesh, + tp_mesh=mock_tp_mesh, + cp_size=1, + dp_size=1, + ) + + data = {"input_ids": torch.randint(0, 1000, (4, 32))} + adapter.set_microbatches(data, 1, torch.tensor(4), torch.tensor(128)) + + adapter.reset() + assert adapter._call_idx == 0 + assert adapter._all_metrics == [] + + def test_call_scales_loss_by_dp_cp( + self, mock_loss_fn, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh + ): + """Loss should be scaled by dp_size * cp_size to cancel FSDP averaging.""" + dp_size = 4 + cp_size = 2 + adapter = PPLossAdapter( + loss_fn=mock_loss_fn, + cfg=base_cfg, + device_mesh=mock_device_mesh, + cp_mesh=mock_cp_mesh, + tp_mesh=mock_tp_mesh, + cp_size=cp_size, + dp_size=dp_size, + ) + + data = { + "input_ids": torch.randint(0, 1000, (4, 32)), + "input_lengths": torch.full((4,), 32), + "sample_mask": torch.ones(4), + } + adapter.set_microbatches(data, 1, torch.tensor(4), torch.tensor(128)) + + logits = torch.randn(4, 32, 32000) + target = torch.randint(0, 32000, (4, 32)) + + mock_output = MagicMock() + mock_output.logits = logits + + result = adapter(mock_output, target) + + base_loss = 0.5 # from mock_loss_fn + expected_scale = dp_size * cp_size + assert abs(result.item() - base_loss * expected_scale) < 1e-5 + + def test_call_increments_index( + self, mock_loss_fn, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh + ): + adapter = PPLossAdapter( + loss_fn=mock_loss_fn, + cfg=base_cfg, + device_mesh=mock_device_mesh, + cp_mesh=mock_cp_mesh, + tp_mesh=mock_tp_mesh, + cp_size=1, + dp_size=1, + ) + + data = { + "input_ids": torch.randint(0, 1000, (4, 32)), + "input_lengths": torch.full((4,), 32), + "sample_mask": torch.ones(4), + } + adapter.set_microbatches(data, 2, torch.tensor(4), torch.tensor(128)) + + logits = torch.randn(2, 32, 32000) + target = torch.randint(0, 32000, (2, 32)) + mock_output = MagicMock() + mock_output.logits = logits + + adapter(mock_output, target) + assert adapter._call_idx == 1 + + adapter(mock_output, target) + assert adapter._call_idx == 2 + + def test_unsqueezes_2d_logits( + self, mock_loss_fn, base_cfg, mock_device_mesh, mock_cp_mesh, mock_tp_mesh + ): + """THD format produces 2D logits [total_tokens, vocab] — adapter should unsqueeze to 3D.""" + adapter = PPLossAdapter( + loss_fn=mock_loss_fn, + cfg=base_cfg, + device_mesh=mock_device_mesh, + 
cp_mesh=mock_cp_mesh, + tp_mesh=mock_tp_mesh, + cp_size=1, + dp_size=1, + ) + + # THD: 1 packed row of 128 tokens + data = { + "input_ids": torch.randint(0, 1000, (1, 128)), + "input_lengths": torch.tensor([128]), + "sample_mask": torch.ones(1), + } + adapter.set_microbatches(data, 1, torch.tensor(1), torch.tensor(128)) + + # 2D logits (THD format: [total_tokens, vocab]) + logits_2d = torch.randn(128, 32000) + # Target must match the unsqueezed batch dim + target = torch.randint(0, 32000, (1, 128)) + mock_output = MagicMock() + mock_output.logits = logits_2d + + # Should not crash — adapter handles 2D by unsqueezing to [1, 128, vocab] + result = adapter(mock_output, target) + assert isinstance(result, torch.Tensor) + + +# ===================== +# Test pad_batch_for_pp +# ===================== +@pytest.mark.automodel +class TestPadBatchForPP: + def test_no_padding_needed(self): + input_ids = torch.randint(0, 1000, (8, 64)) + padded, actual = pad_batch_for_pp(input_ids, pp_batch_size=8) + + assert actual == 8 + assert padded.shape == (8, 64) + assert torch.equal(padded, input_ids) + + def test_padding_added(self): + input_ids = torch.randint(0, 1000, (3, 64)) + padded, actual = pad_batch_for_pp(input_ids, pp_batch_size=8) + + assert actual == 3 + assert padded.shape == (8, 64) + # Original rows preserved + assert torch.equal(padded[:3], input_ids) + # Padding rows are zeros + assert (padded[3:] == 0).all() + + def test_single_sample(self): + input_ids = torch.randint(0, 1000, (1, 32)) + padded, actual = pad_batch_for_pp(input_ids, pp_batch_size=4) + + assert actual == 1 + assert padded.shape == (4, 32) + + def test_preserves_dtype(self): + input_ids = torch.randint(0, 1000, (2, 16), dtype=torch.int32) + padded, _ = pad_batch_for_pp(input_ids, pp_batch_size=4) + + assert padded.dtype == torch.int32 + + +# ===================== +# Test _reset_pp_schedule_state +# ===================== +@pytest.mark.automodel +class TestResetPPScheduleState: + def test_raises_without_update_seq_len(self): + model = MagicMock(spec=[]) # No update_seq_len attribute + + with pytest.raises(RuntimeError, match="update_seq_len.*not found"): + _reset_pp_schedule_state(model, seq_len=64) + + def test_calls_update_seq_len_for_bshd(self): + model = MagicMock() + model._pp_current_seq_len = None + + _reset_pp_schedule_state(model, seq_len=64) + + model.update_seq_len.assert_called_once_with(64) + + def test_calls_reset_thd_for_seqpack(self): + model = MagicMock() + model._pp_current_seq_len = None + + with patch( + "nemo_rl.models.automodel.pipeline_parallel.reset_pp_stage_shapes_for_thd" + ) as mock_reset_thd: + _reset_pp_schedule_state( + model, seq_len=128, seqpack=True, is_hf_model=False + ) + + mock_reset_thd.assert_called_once_with(model, 128) + # _pp_current_seq_len should be cleared for THD + # (THD always needs fresh shapes) + + def test_hf_model_uses_update_seq_len_even_with_seqpack(self): + """HF models don't use THD format, so update_seq_len is used.""" + model = MagicMock() + model._pp_current_seq_len = None + + _reset_pp_schedule_state(model, seq_len=64, seqpack=True, is_hf_model=True) + + model.update_seq_len.assert_called_once_with(64) + + def test_force_clears_cached_seq_len(self): + """force=True clears _pp_current_seq_len so update_seq_len re-initializes.""" + model = MagicMock() + model._pp_current_seq_len = 64 + + _reset_pp_schedule_state(model, seq_len=64, force=True) + + # force=True should clear the cache, causing update_seq_len to be called + # even though seq_len matches the cached value + 
model.update_seq_len.assert_called_once_with(64) + + def test_update_seq_len_skips_when_unchanged(self): + """Without force, update_seq_len handles its own skip logic.""" + model = MagicMock() + model._pp_current_seq_len = 64 + + _reset_pp_schedule_state(model, seq_len=64) + + # update_seq_len is still called — it has its own internal skip + model.update_seq_len.assert_called_once_with(64)
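
For orientation, the contracts these last tests pin down (batch padding for the pipeline schedule and per-microbatch logits capture) can be summarized in a short sketch. This is an illustration reconstructed only from the assertions above; it is not the actual nemo_rl.models.automodel.pipeline_parallel implementation, and anything beyond what the tests check (device handling, THD-specific paths, interaction with the PP schedule) is an assumption.

# Illustrative sketch only -- mirrors the behavior asserted in
# TestPadBatchForPP and TestPPLogitsCapturer, not the real module.
import torch


def pad_batch_for_pp(input_ids: torch.Tensor, pp_batch_size: int):
    """Pad the batch dimension up to pp_batch_size with zero rows.

    Returns (padded_batch, actual_batch_size); dtype and trailing
    dimensions are preserved, and the input is returned unchanged when
    it is already large enough.
    """
    actual = input_ids.shape[0]
    if actual >= pp_batch_size:
        return input_ids, actual
    pad_rows = torch.zeros(
        (pp_batch_size - actual, *input_ids.shape[1:]),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    return torch.cat([input_ids, pad_rows], dim=0), actual


class PPLogitsCapturer:
    """Stand-in loss callable that stashes per-microbatch logits.

    Accepts either a model output object exposing .logits or a raw
    tensor, records a detached copy per call, and returns a zero scalar
    so the pipeline schedule still receives a loss value.
    """

    def __init__(self):
        self.captured_logits = []

    def __call__(self, output, target):
        logits = getattr(output, "logits", output)
        self.captured_logits.append(logits.detach())
        return torch.zeros((), device=logits.device)

    def reset(self):
        self.captured_logits.clear()

Under these assumptions, a last-stage rank would run the schedule with the capturer as its loss function, then concatenate capturer.captured_logits and slice back to the actual batch size returned by pad_batch_for_pp before handing logits to the real loss.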