Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# DPO: Llama 3.1 8B on Tulu3, 1 node x 8 GPUs, 2-way pipeline parallelism.
# Overrides the FSDP2/TP1 base recipe. Indentation reconstructed from a
# flattened paste — verify nesting against the repo file.
defaults: ./dpo-llama3.1-8b-tulu3-1n8g-fsdp2tp1.yaml
policy:
  train_micro_batch_size: 2
  dtensor_cfg:
    _v2: true
    pipeline_parallel_size: 2
    automodel_kwargs:
      force_hf: true
      pipeline_config:
        _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig
        pp_schedule: 1F1B
        pp_microbatch_size: 1
checkpointing:
  checkpoint_dir: results/dpo-llama3.1-8b-tulu3-1n8g-pp2
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# GRPO: Moonlight 16B, 1n8g, CP2 + sequence packing (dynamic batching off).
defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml
policy:
  sequence_packing:
    enabled: true
  dynamic_batching:
    enabled: false
  dtensor_cfg:
    context_parallel_size: 2
checkpointing:
  # NOTE(review): dir name says cp2ep4 but this file only sets cp=2 on top of
  # the ep8 base — confirm the effective expert-parallel size matches the name.
  checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-cp2ep4-seqpack
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# GRPO: Moonlight 16B, 1n8g, CP2 x EP4 (dynamic batching off).
defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml
policy:
  dynamic_batching:
    enabled: false
  dtensor_cfg:
    context_parallel_size: 2
    expert_parallel_size: 4
checkpointing:
  checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-cp2ep4
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# GRPO: Moonlight 16B, 1n8g, EP8 base + sequence packing (dynamic batching off).
defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml
policy:
  sequence_packing:
    enabled: true
  dynamic_batching:
    enabled: false
checkpointing:
  checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-ep8-seqpack
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# GRPO: Moonlight 16B, 1n8g, PP2 x EP4 + sequence packing at 2048 tokens.
defaults: ./grpo-moonlight-16b-automodel-1n8g-pp2ep4.yaml
policy:
  max_total_sequence_length: 2048
  sequence_packing:
    enabled: true
    train_mb_tokens: 2048
data:
  max_input_seq_length: 2048
checkpointing:
  checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-pp2ep4-seqpack
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# GRPO: Moonlight 16B, 1n8g, PP2 x EP4 with the interleaved 1F1B schedule.
defaults: ./grpo-moonlight-16b-automodel-1n8g-ep8.yaml
policy:
  train_micro_batch_size: 4
  dtensor_cfg:
    pipeline_parallel_size: 2
    expert_parallel_size: 4
    automodel_kwargs:
      pipeline_config:
        _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig
        # NOTE(review): Interleaved1F1B normally requires more than one
        # virtual stage per PP rank — confirm the model split supports it.
        pp_schedule: Interleaved1F1B
        pp_microbatch_size: 2
  dynamic_batching:
    enabled: false
checkpointing:
  checkpoint_dir: results/grpo-moonlight-16b-automodel-1n8g-pp2ep4
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SFT: GPT-OSS 20B, 1n8g, CP2 + sequence packing on top of the FSDP8/EP8 base.
defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml
policy:
  dtensor_cfg:
    context_parallel_size: 2
  sequence_packing:
    enabled: true
checkpointing:
  # NOTE(review): dir name says cp2ep4 but the base is ep8 and only cp=2 is
  # overridden here — confirm the effective expert-parallel size.
  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-cp2ep4-seqpack-automodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SFT: GPT-OSS 20B, 1n8g, FSDP8/EP8 base + sequence packing.
defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml
policy:
  sequence_packing:
    enabled: true
checkpointing:
  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-fsdp8ep8-seqpack-automodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SFT: GPT-OSS 20B, 1n8g, PP2 x CP2 + sequence packing (skip initial validation).
defaults: ./sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml
sft:
  val_at_start: false
policy:
  sequence_packing:
    enabled: true
  dtensor_cfg:
    context_parallel_size: 2
checkpointing:
  # NOTE(review): dir says pp2cp2ep2 — presumably EP drops from 4 to 2 once
  # CP=2 is layered onto the pp2ep4 base on 8 GPUs; confirm.
  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2cp2ep2-seqpack-automodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SFT: GPT-OSS 20B, 1n8g, PP2 x EP4 with the 1F1B pipeline schedule.
defaults: ./sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml
sft:
  val_micro_batch_size: 4
policy:
  dtensor_cfg:
    pipeline_parallel_size: 2
    expert_parallel_size: 4
    automodel_kwargs:
      pipeline_config:
        _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig
        # NOTE(review): sibling configs spell this "1F1B" — confirm the
        # schedule-name lookup is case-insensitive.
        pp_schedule: 1f1b
        pp_microbatch_size: 4
checkpointing:
  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2ep4-automodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SFT: GPT-OSS 20B, 1n8g, PP2 x EP4 + sequence packing (skip initial validation).
defaults: ./sft-gpt-oss-20b-1n8g-pp2ep4-automodel.yaml
sft:
  val_at_start: false
policy:
  sequence_packing:
    enabled: true
checkpointing:
  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-pp2ep4-seqpack-automodel
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SFT: Moonlight 16B, 1n8g, CP2 with sequence packing disabled
# (the base recipe enables packing; no checkpoint_dir override here).
defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml
policy:
  sequence_packing:
    enabled: false
  dtensor_cfg:
    context_parallel_size: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SFT: Moonlight 16B, 1n8g, CP2 on top of the seq-packed FSDP8/EP8 base.
defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml
policy:
  dtensor_cfg:
    context_parallel_size: 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SFT: Moonlight 16B, 1n8g, base recipe with sequence packing turned off.
defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml
policy:
  sequence_packing:
    enabled: false
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SFT base recipe: Moonlight 16B-A3B Instruct, 1 node x 8 GPUs, EP8 with
# sequence packing and the TE (Transformer Engine) automodel backend.
# Indentation reconstructed from a flattened paste — verify nesting,
# in particular which keys belong under `backend` vs `automodel_kwargs`.
defaults: ../../sft.yaml
policy:
  model_name: moonshotai/Moonlight-16B-A3B-Instruct
  train_global_batch_size: 128
  train_micro_batch_size: 8
  max_total_sequence_length: 512
  sequence_packing:
    enabled: true
  dtensor_cfg:
    expert_parallel_size: 8
    automodel_kwargs:
      backend:
        _target_: nemo_automodel.components.models.common.utils.BackendConfig
        attn: te
        linear: te
        rms_norm: te
        enable_deepep: true
        rope_fusion: false
        enable_hf_state_dict_adapter: true
        enable_fsdp_optimizations: false
  make_sequence_length_divisible_by: 4
sft:
  val_at_start: false
checkpointing:
  enabled: false
cluster:
  gpus_per_node: 8
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SFT: Moonlight 16B, 1n8g, PP2 x CP2 x EP4 on the seq-packed base
# (no checkpoint_dir override in this file).
defaults: ./sft-moonlight-16b-1n8g-fsdp8ep8-seqpack-automodel.yaml
policy:
  dtensor_cfg:
    pipeline_parallel_size: 2
    context_parallel_size: 2
    expert_parallel_size: 4
    automodel_kwargs:
      pipeline_config:
        _target_: nemo_automodel.components.distributed.pipelining.config.PipelineConfig
        # NOTE(review): sibling configs spell this "1F1B" — confirm the
        # schedule-name lookup is case-insensitive.
        pp_schedule: 1f1b
        pp_microbatch_size: 4
94 changes: 90 additions & 4 deletions nemo_rl/distributed/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,68 @@
)


class _AllGatherCPNoReduceBackward(torch.autograd.Function):
    """Like AllGatherCPTensor but WITHOUT all-reduce in backward.

    Used when each CP rank computes loss on allgathered (identical) logprobs.
    In this case, all ranks produce the same gradient, so all-reduce would
    double it. This version just slices the gradient to the local chunk.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, cp_group, seq_dim: int = 1) -> torch.Tensor:
        """All-gather ``tensor`` over the CP group and undo dual-chunk-swap ordering.

        Each CP rank holds two half-chunks along ``seq_dim`` in the
        dual-chunk-swap (zigzag) layout: rank ``r`` owns global chunks
        ``r`` and ``2*cp_size - 1 - r``. The gathered pieces are re-sorted
        by global chunk index so the result is in natural sequence order.
        """
        cp_size = torch.distributed.get_world_size(cp_group)
        cp_rank_chunks = [torch.empty_like(tensor) for _ in range(cp_size)]
        torch.distributed.all_gather(cp_rank_chunks, tensor, group=cp_group)

        # Undo dual-chunk-swap ordering
        tensor_chunks = []
        for chunk in cp_rank_chunks:
            # Each rank's contribution carries two half-chunks along seq_dim.
            tensor_chunks.extend(torch.chunk(chunk, chunks=2, dim=seq_dim))
        # Global chunk indices in gather order: rank r contributed chunks
        # (r, 2*cp_size - r - 1).
        chunk_indices = []
        for r in range(cp_size):
            chunk_indices.append(r)
            chunk_indices.append(2 * cp_size - r - 1)
        # Sort half-chunks back into natural sequence order, then concatenate.
        pairs = sorted(zip(tensor_chunks, chunk_indices), key=lambda t: t[1])
        ret = torch.cat([c for c, _ in pairs], dim=seq_dim)

        # Stash what backward needs; no tensors are saved since backward only
        # slices the incoming gradient.
        ctx.seq_dim = seq_dim
        ctx.cp_group = cp_group
        return ret

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Slice the incoming gradient down to this rank's two half-chunks.

        No collective is performed: every rank computed the loss on the same
        gathered tensor, so ``grad_output`` is already the full gradient and
        an all-reduce would scale it by ``cp_size``.
        """
        cp_size = torch.distributed.get_world_size(ctx.cp_group)
        cp_rank = torch.distributed.get_rank(ctx.cp_group)
        seq_dim = ctx.seq_dim
        # NO all-reduce — just select this rank's chunk
        # Fold seq_dim into (2*cp_size, chunk_len) so chunks are addressable
        # by global chunk index.
        grad_output = grad_output.view(
            *grad_output.shape[:seq_dim],
            2 * cp_size,
            grad_output.shape[seq_dim] // (2 * cp_size),
            *grad_output.shape[seq_dim + 1 :],
        )
        # This rank's half-chunks under the dual-chunk-swap layout.
        # Pinned host tensor + non-blocking H2D copy; assumes CUDA is
        # available and grad_output lives on the current CUDA device.
        index = torch.tensor(
            [cp_rank, 2 * cp_size - cp_rank - 1], device="cpu", pin_memory=True
        ).cuda(non_blocking=True)
        grad_input = grad_output.index_select(seq_dim, index)
        # Merge the two selected half-chunks back into a single seq_dim.
        grad_input = grad_input.view(
            *grad_input.shape[:seq_dim], -1, *grad_input.shape[seq_dim + 2 :]
        )
        # One gradient per forward input (tensor, cp_group, seq_dim); only the
        # tensor argument is differentiable.
        return grad_input, None, None


# Lazily-created process group containing only the calling rank; cached so a
# process never allocates more than one such group.
_SELF_TP_GROUP: Optional[torch.distributed.ProcessGroup] = None


def _get_self_tp_group() -> torch.distributed.ProcessGroup:
    """Return (creating on first use) a process group holding only this rank.

    Serves as a trivial single-rank "tensor-parallel" group so the CP-only
    logprob path can reuse helpers that expect a TP group.
    """
    global _SELF_TP_GROUP
    if _SELF_TP_GROUP is not None:
        return _SELF_TP_GROUP
    own_rank = torch.distributed.get_rank()
    _SELF_TP_GROUP = torch.distributed.new_group([own_rank])
    return _SELF_TP_GROUP


@torch.no_grad()
def _compute_distributed_log_softmax(
vocab_parallel_logits: torch.Tensor, group: torch.distributed.ProcessGroup
Expand Down Expand Up @@ -941,6 +1003,7 @@ def from_parallel_logits_to_logprobs(
tp_group: torch.distributed.ProcessGroup,
inference_only: bool = False,
cp_group: Optional[torch.distributed.ProcessGroup] = None,
cp_allgather_no_reduce: bool = False,
chunk_size: Optional[int] = None,
sampling_params: Optional[TrainingSamplingParams] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -1020,10 +1083,12 @@ def from_parallel_logits_to_logprobs(
).contiguous()

if cp_size > 1:
# we need to gather the logits by context parallelism
logprobs = allgather_cp_sharded_tensor(
logprobs, cp_group, seq_dim=1
) # , unpadded_seqlen=target.shape[1])
if cp_allgather_no_reduce:
# No all-reduce in backward — used when each CP rank computes
# loss on allgathered logprobs (avoids 2x gradient).
logprobs = _AllGatherCPNoReduceBackward.apply(logprobs, cp_group, 1)
else:
logprobs = allgather_cp_sharded_tensor(logprobs, cp_group, seq_dim=1)

if pad_len > 0:
logprobs = logprobs[:, :-pad_len]
Expand Down Expand Up @@ -1394,6 +1459,27 @@ def get_next_token_logprobs_from_logits(
# slice off to the correct length to remove potential CP padding
logprobs = logprobs[:, : input_ids.shape[1] - 1]

elif context_parallel_group is not None and not isinstance(
next_token_logits, torch.distributed.tensor.DTensor
):
# CP-only path (no TP): automodel backend with full vocab per rank.
# TE's thd_get_partitioned_indices uses dual-chunk-swap (same as
# _get_tokens_on_this_cp_rank), so from_parallel_logits_to_logprobs works.
tp_group = _get_self_tp_group()
vocab_size = next_token_logits.shape[-1]
logprobs = from_parallel_logits_to_logprobs(
next_token_logits,
input_ids,
vocab_start_index=0,
vocab_end_index=vocab_size,
tp_group=tp_group,
inference_only=False,
cp_group=context_parallel_group,
cp_allgather_no_reduce=True, # avoid 2x gradient from allgather backward
sampling_params=sampling_params,
)
logprobs = logprobs[:, : input_ids.shape[1] - 1]

elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits,
Expand Down
Loading
Loading