
feat: add PP + CP + seqpack support for automodel backend#2242

Draft
hemildesai wants to merge 9 commits into main from hemil/automodel-pp-cp

Conversation

@hemildesai
Contributor

Summary

Adds pipeline parallelism (PP), context parallelism (CP), and sequence packing (THD format) support for the automodel backend, with all combinations working end-to-end for SFT and GRPO.

Key changes

  • Pipeline parallelism: pipeline_parallel.py module with PPLossAdapter, pp_forward_backward(), PP stage shape management, and broadcast helpers. Supports 1F1B and Interleaved1F1B schedules via automodel's AutoPipeline.
  • Sequence packing (THD format): pack_for_thd() in data.py packs variable-length sequences into TE's THD attention format. Works with all custom automodel models (Moonlight, GPT-OSS, Qwen3-MoE). Includes CP padding (2*cp_size) for dual-chunk-swap compatibility.
  • CP + seqpack: Uses SequencePackingLossWrapper with context_parallel_group for loss computation across CP ranks. Logprobs use from_parallel_logits_to_logprobs with a single-rank TP group + CP group.
  • PP + seqpack: THD squeeze hooks, dummy batch padding, reset_pp_stage_shapes_for_thd() for dynamic token counts, prepare_pp_seqpack_batch() for PP-specific packing.
  • BSHD CP (no seqpack): PyTorch DTensor context_parallel context manager path verified working for both SFT and GRPO.
  • Refactored PP methods: Extracted pp_forward_loop(), pp_forward_with_post_processing(), and pp_train_forward_backward_loop() into pipeline_parallel.py, reducing _get_logprobs_pp from ~104 to ~20 lines, _get_topk_logits_pp from ~71 to ~20 lines, and _train_pp from ~260 to ~80 lines.
  • Bug fixes: missing pad_batch_for_pp call in the topk PP path, inconsistent _pp_current_seq_len reset, and input_lengths computed from unclamped slice indices.
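
To make the THD packing concrete, here is a framework-free sketch of what sequence packing with CP padding involves. The real pack_for_thd() in data.py operates on tensors and returns a THDBatch; names such as PAD_ID, IGNORE_INDEX, and the function signature below are illustrative assumptions, not the actual API.

```python
PAD_ID = 0
IGNORE_INDEX = -100  # standard ignore index for cross-entropy losses


def pack_for_thd_sketch(sequences, prompt_lens, cp_size=1):
    """Pack variable-length sequences into one flat (THD-style) buffer.

    sequences: list of token-id lists; prompt_lens: prompt length per sequence.
    Pads the packed length to a multiple of 2*cp_size so context parallelism
    can split each rank's work into the two chunks used by the dual-chunk swap.
    """
    input_ids, labels, cu_seqlens = [], [], [0]
    for seq, p_len in zip(sequences, prompt_lens):
        input_ids.extend(seq)
        # Prompt tokens are masked out of the loss with IGNORE_INDEX.
        labels.extend([IGNORE_INDEX] * p_len + seq[p_len:])
        cu_seqlens.append(len(input_ids))
    # Pad the total token count up to a multiple of 2*cp_size.
    multiple = 2 * cp_size
    pad = (-len(input_ids)) % multiple
    input_ids.extend([PAD_ID] * pad)
    labels.extend([IGNORE_INDEX] * pad)
    if pad:
        cu_seqlens.append(len(input_ids))  # padding forms a trailing dummy segment
    return input_ids, labels, cu_seqlens


# Two sequences of lengths 3 and 2, each with a 1-token prompt, cp_size=2.
packed, labels, cu = pack_for_thd_sketch([[1, 2, 3], [4, 5]], [1, 1], cp_size=2)
```

The cu_seqlens boundaries are what TE's THD attention consumes to keep attention from crossing sequence boundaries inside the packed buffer.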

Test plan

Unit tests (20 new tests)

  • tests/unit/models/automodel/test_pipeline_parallel.py — PPLogitsCapturer, PPLossAdapter, pad_batch_for_pp, _reset_pp_schedule_state

SFT verification (loss matches within ~0.3% of baseline)

| Config | Steps | Expected Loss (step 1) | Status |
|---|---|---|---|
| moonlight EP=8 seqpack (baseline) | 3 | 2.6264 | Verified |
| moonlight CP=2 EP=4 seqpack | 3 | ~2.62 | Verified |
| moonlight CP=2 EP=4 BSHD (no seqpack) | 3 | ~2.62 | Verified |
| moonlight PP=2 CP=2 EP=4 seqpack | 3 | ~2.62 | Verified |
| moonlight EP=8 BSHD baseline (no seqpack) | 3 | 2.6264 | Verified |
| gpt-oss PP=2 EP=4 seqpack | 3 | N/A | Verified (runs) |

GRPO verification (exact match on loss, gen_kl_error, reward)

| Config | Steps | Loss | gen_kl_error | grad_norm | Status |
|---|---|---|---|---|---|
| moonlight EP=8 baseline | 1 | -0.0111 | 0.0003 | 0.17825 | Verified |
| moonlight CP=2 EP=4 BSHD | 1 | -0.0111 | 0.0003 | | Verified (exact) |
| moonlight PP=2 EP=4 | 1 | -0.0111 | 0.0003 | 0.17816 | Verified |
| moonlight CP=2 EP=4 seqpack | 1 | -0.0111 | 0.0003 | | Verified |

Nightly test scripts

  • sft-gpt-oss-20b-1n8g-pp2ep4-automodel.sh
  • sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel.sh
  • grpo-moonlight-16b-automodel-1n8g-pp2ep4.sh
  • grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack.sh

🤖 Generated with Claude Code

hemildesai and others added 7 commits April 6, 2026 21:50
Add PP support to the automodel (DTensor) training backend, enabling
pipeline-parallel training for SFT and GRPO with custom nemo_automodel
models (GPT-OSS, Moonlight/DeepseekV3, Qwen3 MoE).

Key changes:
- Add pipeline_parallel_size to DTensorConfig and DistributedContext
- Resolve PipelineConfig from automodel_kwargs, detect custom models to
  disable HF forward patching (patch_inner_model/patch_causal_lm_model)
- Create per-part optimizers/schedulers for PP model stages
- Add PPLossAdapter (stateful loss for PP schedule), PPLogprobsCapturer,
  pp_forward_backward, and PP broadcast utilities in train.py
- Add gradient accumulation loop in _train_pp matching automodel's
  train_ft.py pattern (prepare_for_grad_accumulation/final_backward)
- Force-reset PP schedule state before training to prevent stale stage
  initialization from prior schedule.eval() calls
- Add PP-aware logprobs (_get_logprobs_pp), topk (_get_topk_logits_pp),
  reference model swap, weight streaming, checkpoint, and device moves
- Support dynamic sequence lengths via AutoPipeline.update_seq_len()
- Read pp_size in LMPolicy for correct data sharding with PP

Verified: SFT PP=2 loss/grad_norm match PP=1 baseline within 1%.
GRPO PP=2 with Interleaved1F1B completes end-to-end (generation,
logprobs, training, weight streaming) with matching metrics.
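
The stateful loss mentioned above exists because PP schedules invoke the loss function once per microbatch with only the stage's output; the adapter must remember which microbatch's targets to use. A framework-free sketch of that pattern (the real PPLossAdapter operates on tensors and a torch loss function; the class and method names below are simplified assumptions):

```python
class LossAdapterSketch:
    """Stateful loss for a pipeline schedule: tracks a call index into
    pre-split microbatch targets, since the schedule passes only outputs."""

    def __init__(self, loss_fn, targets, num_microbatches):
        self._loss_fn = loss_fn
        # Split the full batch of targets into one chunk per microbatch.
        mb = len(targets) // num_microbatches
        self._chunks = [targets[i * mb:(i + 1) * mb] for i in range(num_microbatches)]
        self._num_microbatches = num_microbatches
        self._call_idx = 0

    def __call__(self, output):
        labels = self._chunks[self._call_idx]
        self._call_idx += 1
        # Scale so the sum of microbatch losses equals the full-batch mean.
        return self._loss_fn(output, labels) / self._num_microbatches

    def reset(self):
        """Must be called between schedule steps to avoid stale indices."""
        self._call_idx = 0


def mse(pred, labels):
    return sum((p - t) ** 2 for p, t in zip(pred, labels)) / len(labels)


adapter = LossAdapterSketch(mse, targets=[1.0, 2.0, 3.0, 4.0], num_microbatches=2)
loss0 = adapter([1.0, 2.0])  # microbatch 0: exact match
loss1 = adapter([3.0, 5.0])  # microbatch 1: one element off by 1
adapter.reset()
```

Forgetting the reset between steps is exactly the kind of stale-state failure the force-reset of PP schedule state guards against.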

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…backend

Refactors the PP implementation introduced in aaf2523:

- Extract ModelHandle wrapper (model_handle.py) for unified nn.Module/AutoPipeline
  interface, eliminating most `if pp_enabled` branching
- Extract pipeline_parallel.py with PPLossAdapter, PPLogprobsCapturer,
  PPTopkCapturer, pp_forward_backward, and broadcast helpers
- Extract build_pipeline_config() helper in setup.py for PipelineConfig resolution
- Always-list optimizers/schedulers in ModelAndOptimizerState for PP multi-stage
- PP-aware checkpoint save/load passing model parts list and optimizer list
- Broadcast PP loss from last stage so validation loss is reported correctly
- Add DPO PP=2 config and test script

Verified: SFT PP=2 (loss 5.73->3.14, val_loss 6.33), GRPO PP=2 (gen_kl 0.0003,
grad_norm 0.178, reward 0.584) all match PP=1 baselines within 0.5%.
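
The ModelHandle idea above can be sketched in a few lines: present a uniform list-of-parts interface whether the underlying model is a single module (non-PP) or a pipeline with multiple stage parts, so callers iterate over parts instead of branching on `if pp_enabled`. The attribute and method names here are assumptions for illustration, not the real model_handle.py API:

```python
class ModelHandleSketch:
    """Uniform wrapper over a single module or a list of PP stage parts."""

    def __init__(self, model_or_parts):
        if isinstance(model_or_parts, list):
            self.parts = model_or_parts   # PP: one module per local stage
            self.pp_enabled = True
        else:
            self.parts = [model_or_parts]  # non-PP: a single module
            self.pp_enabled = False

    def apply(self, fn):
        """Apply fn (e.g. mode switches, device moves) to every part uniformly."""
        for part in self.parts:
            fn(part)


single = ModelHandleSketch("model")
pipeline = ModelHandleSketch(["stage0", "stage1"])
```

Optimizers and schedulers stored as always-lists (one entry per part) follow the same principle: the non-PP case is just the length-1 instance of the PP case.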

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
Adds sequence packing support using TE's THD format
for the automodel backend, with CP (context parallelism) and PP support.

Key changes:
- data.py: THDBatch dataclass, pack_for_thd() for CP-padded packing with
  token_mask-based labels (-100 for prompt+padding)
- train.py: model_forward THD path, LossPostProcessor with
  SequencePackingLossWrapper + context_parallel_group (same as Megatron),
  LogprobsPostProcessor THD per-sequence path
- pipeline_parallel.py: PPLossAdapter with SequencePackingLossWrapper
- model_utils.py: CP-only path in get_next_token_logprobs_from_logits
- Configs and test scripts for seqpack, CP+seqpack, PP+seqpack

Verified: all seqpack configs match non-seqpack baselines within 1%.
Seqpack gives 3.8x speedup at seq_len=4096.
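
The per-sequence THD path mentioned above reduces to splitting a flat packed buffer back into per-sequence slices at the cu_seqlens boundaries. A minimal framework-free sketch (the real LogprobsPostProcessor path does this on tensors; the function name is hypothetical):

```python
def unpack_thd(flat_values, cu_seqlens):
    """Split one flat packed buffer into per-sequence slices using the
    cumulative sequence-length boundaries produced at packing time."""
    return [flat_values[s:e] for s, e in zip(cu_seqlens, cu_seqlens[1:])]


# Two sequences of lengths 3 and 2 packed into one 5-token buffer.
per_seq = unpack_thd([0.1, 0.2, 0.3, 0.4, 0.5], [0, 3, 5])
```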

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
… configs

Add _AllGatherCPNoReduceBackward that skips all-reduce in backward pass
when each CP rank computes loss on allgathered (identical) logprobs,
preventing gradient doubling. Fix EP sizes: CP=2 configs now inherit
EP=8 from parent (DP+CP share EP axis); PP2+CP2 configs corrected to EP=4.
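
A pure-Python numeric illustration of why the no-reduce backward matters (no torch, no distributed; the function names are invented for the example). When every CP rank computes the same loss on identical allgathered values, a standard all-gather backward all-reduces (sums) the incoming gradients across ranks, scaling each local gradient by cp_size; keeping only the local rank's gradient slice avoids the doubling:

```python
def backward_with_reduce(grads_per_rank, shard):
    # Standard all-gather backward: sum gradients from every rank,
    # then keep the slice this rank owns.
    n = len(grads_per_rank[0])
    summed = [sum(g[i] for g in grads_per_rank) for i in range(n)]
    return summed[shard[0]:shard[1]]


def backward_no_reduce(grads_per_rank, rank, shard):
    # No-reduce variant: keep only this rank's own incoming gradient slice.
    return grads_per_rank[rank][shard[0]:shard[1]]


cp_size = 2
# Each rank computed the same loss on the identical gathered buffer, so
# every rank's upstream gradient for the full buffer is the same:
grads = [[0.5, 0.5, 0.5, 0.5] for _ in range(cp_size)]
doubled = backward_with_reduce(grads, shard=(0, 2))
correct = backward_no_reduce(grads, rank=0, shard=(0, 2))
```

With cp_size=2 the reduced path yields gradients twice the correct magnitude, which matches the "gradient doubling" this commit fixes.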

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…lel.py

Extract shared PP eval logic from _get_logprobs_pp and _get_topk_logits_pp
into pp_forward_with_post_processing (single-chunk) and pp_forward_loop
(full iteration), mirroring the non-PP forward_with_post_processing_fn
pattern. Merge identical PPLogprobsCapturer/PPTopkCapturer into single
PPLogitsCapturer. Also fixes missing pad_batch_for_pp and _pp_current_seq_len
reset in the topk PP path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
…ackward_loop

Move the PP gradient-accumulation inner loop (~140 lines) from _train_pp
into pp_train_forward_backward_loop() in pipeline_parallel.py, mirroring
automodel_forward_backward() for the non-PP path. Also extract shared
_reset_pp_schedule_state() helper used by both eval and train PP paths,
with force=True for train to handle eval→train schedule transitions.

Fixes: input_lengths slicing used unclamped indices (bug in partial batches),
inconsistent hasattr check between eval/train paths, redundant loop-invariant
recomputation.

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Tests for PPLogitsCapturer (capture, reset, detach, raw tensor),
PPLossAdapter (microbatch splitting, reset, dp*cp scaling, call index,
2D logit unsqueeze), pad_batch_for_pp (padding, no-op, dtype), and
_reset_pp_schedule_state (BSHD, THD, HF, force mode).

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot bot commented Apr 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hemildesai hemildesai changed the title feat: pipeline parallelism + context parallelism + sequence packing for automodel feat: add PP + CP + seqpack support for automodel backend Apr 9, 2026
@hemildesai
Contributor Author

/claude review

These are nightly test output artifacts that should not be committed.

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>

claude bot left a comment


Large, well-tested PR. The PP + CP + seqpack matrix is thoroughly covered by nightly scripts and unit tests, and the loss/metric verification in the test plan is solid.

One potential bug flagged inline: _all_params_generator always uses parts[0] for FQN adaptation on the owner rank, inconsistent with prepare_refit_info which correctly passes the owning part.

…rator

_maybe_adapt_tensor_to_hf was always called with parts[0] regardless of
which PP stage owned the parameter. Pass the actual owning part from
_get_local_param_tensor so FQN adaptation is correct for all stages.

Signed-off-by: hemildesai <hemild@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
@hemildesai hemildesai added the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label Apr 10, 2026
@hemildesai
Contributor Author

/ok to test 47442fa
