429 changes: 429 additions & 0 deletions CROSS_TOKENIZER_README.md

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions eval_results.csv
@@ -0,0 +1,11 @@
Model,top_k,MATH (%),MATH (correct),MATH (total),MMLU (%),MMLU (correct),MMLU (total),wandb
Llama-3.1-8B (teacher),--,12.56,628,5000,35.14,4935,14042,
Llama-3.2-1B (pretrained baseline),--,5.64,282,5000,23.06,3238,14042,
Distillation forward KL (50 steps),64,5.52,276,5000,23.25,3265,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1
Distillation forward KL (1000 steps),64,5.84,292,5000,26.24,3684,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1
Delta from baseline (top_k=64),,+0.20,+10,,+3.18,+446,,
Distillation forward KL (50 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz
Distillation forward KL (1000 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz
SFT (50 steps),--,4.06,203,5000,16.79,2358,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw
SFT (1000 steps),--,7.30,365,5000,22.13,3108,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw
Delta from baseline (SFT),,+1.66,+83,,-0.93,-130,,
245 changes: 245 additions & 0 deletions examples/configs/cross_tokenizer_off_policy_arrow.yaml
@@ -0,0 +1,245 @@
# Cross-Tokenizer Off-Policy Distillation Configuration
# Student: Llama-3.2-1B (128K vocab) <- Teacher: Phi-4-mini-instruct (200K vocab)
#
# Requires a precomputed projection matrix. Generate with:
#   python nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py \
#     --student-model meta-llama/Llama-3.2-1B \
#     --teacher-model microsoft/Phi-4-mini-instruct
#
# Then optionally enforce exact matches:
#   python nemo_rl/algorithms/x_token/reapply_exact_map.py \
#     --student-model meta-llama/Llama-3.2-1B \
#     --teacher-model microsoft/Phi-4-mini-instruct \
#     --initial-projection-path cross_tokenizer_data/transformation_counts_via_multitoken.pt

token_aligner:
enabled: true
projection_matrix_path: "cross_tokenizer_data/projection_map_Llama-3.2_to_Phi-4-mini-instruct_multitoken_top_32_double_special.pt"
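  # This path must point at the matrix produced by the projection scripts referenced in
  # the header above (minimal_projection_via_multitoken.py, optionally refined by
  # reapply_exact_map.py); the run requires this file to exist before training starts.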
use_sparse_format: false
loss_type: "KL"
exact_token_match_only: false
temperature: 1.0
vocab_topk: 8192
reverse_kl: false
projection_matrix_multiplier: 1.0
max_comb_len: 4
learnable: false
project_teacher_to_student: false # TODO: remove; duplicated under loss_fn below
use_char_offset: false
use_align_fast: true
use_cuda_dp: false
dp_chunk_size: 128
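  # Rough picture of what these knobs control (a sketch, not authoritative): teacher
  # outputs are decoded to text and re-aligned to student tokens (use_align_fast /
  # use_char_offset / use_cuda_dp select the alignment routine, dp_chunk_size its
  # chunking), and the teacher's top-vocab_topk logits are then mapped into the student
  # vocabulary through the projection matrix before the distillation loss is computed.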

distillation:
num_prompts_per_step: 768
num_generations_per_prompt: 1
max_num_steps: 80000
max_num_epochs: 1
val_period: 1000
val_at_start: false
max_val_samples: 128
val_batch_size: 64
topk_logits_k: 8192
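  # Presumably only the teacher's top-8192 logits per position are kept and transferred;
  # this value is intended to match vocab_topk in token_aligner and loss_fn.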
use_ipc: true
loss_on_all_tokens: true
seed: 42
# Number of CPU processes for cross-tokenizer decode/encode/align.
# Heuristic currently used: total GPUs / 2 (16 nodes * 8 GPUs = 128 -> 64 workers).
# Note: this is workload-dependent; tune per batch size and cluster shape.
cross_tokenizer_num_workers: 64

loss_fn:
loss_type: "KL"
temperature: 1.0
vocab_topk: 8192
exact_token_match_only: false
reverse_kl: false
project_teacher_to_student: false
gold_loss: true
xtoken_loss: true
ce_loss_scale: 0.1
dynamic_loss_scaling: true
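  # Assumed composition of the objective given the flags above (a sketch, not confirmed
  # against the code): total = cross-tokenizer forward KL over the top-8192 projected
  # teacher logits (T=1.0) + ce_loss_scale * cross-entropy on the gold tokens, with
  # dynamic_loss_scaling rebalancing the two terms during training.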

checkpointing:
enabled: true
checkpoint_dir: "checkpoints/cross-tokenizer-distillation-llama1b-phi4mini"
metric_name: "train:loss"
higher_is_better: false
keep_top_k: 3
save_period: 10
model_save_format: "safetensors"
save_consolidated: false

policy:
model_name: "meta-llama/Llama-3.2-1B"
tokenizer:
name: "meta-llama/Llama-3.2-1B"
chat_template: null
train_global_batch_size: 768
train_micro_batch_size: 1
max_total_sequence_length: 4096
precision: "bfloat16"

dtensor_cfg:
enabled: true
_v2: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

max_grad_norm: 1.0

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-5
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
foreach: false
fused: false

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.02
end_factor: 1.0
total_iters: 4000
- name: "torch.optim.lr_scheduler.CosineAnnealingLR"
kwargs:
T_max: 76000
eta_min: 0.0
- milestones: [4000]
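    # Schedule arithmetic: 4000 linear warmup steps, then cosine decay for T_max=76000
    # steps; the milestone at step 4000 is where the switch happens, and
    # 4000 + 76000 = 80000 matches distillation.max_num_steps.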

generation:
backend: "vllm"
max_new_tokens: 2048
temperature: 0.0
top_p: 1.0
top_k: null
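    # temperature 0.0 (with top_p 1.0 and top_k disabled) gives greedy decoding for
    # the student rollouts.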
stop_token_ids: null
stop_strings: null
vllm_cfg:
async_engine: false
precision: "bfloat16"
kv_cache_dtype: "auto"
tensor_parallel_size: 1
pipeline_parallel_size: 1
expert_parallel_size: 1
gpu_memory_utilization: 0.6
max_model_len: 2048
enforce_eager: false
use_deep_gemm: false
num_last_layers_in_bf16: 0
num_first_layers_in_bf16: 0
distributed_executor_backend: null
colocated:
enabled: true
resources:
gpus_per_node: null
num_nodes: null
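    # With colocated.enabled: true the vLLM engine shares GPUs with the training
    # workers, which is presumably why gpu_memory_utilization is capped at 0.6 to
    # leave headroom for the training process.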

teacher:
model_name: "microsoft/Phi-4-mini-instruct"
tokenizer:
name: "microsoft/Phi-4-mini-instruct"
chat_template: null
precision: "bfloat16"
train_global_batch_size: 768
train_micro_batch_size: 1
logprob_batch_size: 1
max_total_sequence_length: 4096
max_grad_norm: 1.0
logprob_chunk_size: null
offload_optimizer_for_logprob: false

dtensor_cfg:
enabled: true
_v2: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

dynamic_batching:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
sequence_length_round: 64

sequence_packing:
enabled: false
train_mb_tokens: 4096
logprob_mb_tokens: 4096
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-5
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5
foreach: false
fused: false

generation: null

data:
dataset_name: "arrow_text"
arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-00[0-4][0-9][0-9]-of-02476.arrow"
# arrow_files="/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-000[0-5][0-9]-of-02476.arrow"
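  # Shard selection: the active glob matches shard indices 00000-00499 (500 of the 2476
  # shards); the commented-out alternative matches only 00000-00059 (60 shards).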
prompt_file: null
max_input_seq_length: 4096
shuffle: true

eval:
val_period: 1000
val_at_start: false
max_val_samples: 512
val_batch_size: 64
max_rollout_turns: 1
benchmarks:
math:
dataset_name: "math"
prompt_file: "examples/prompts/cot.txt"
env:
num_workers: 8
mmlu:
dataset_name: "mmlu"
prompt_file: "examples/prompts/mmlu.txt"
env:
num_workers: 8
verifier_type: "multilingual_multichoice"
mmlu_5shot:
dataset_name: "mmlu"
prompt_file: "examples/prompts/mmlu.txt"
num_few_shot: 5
env:
num_workers: 8
verifier_type: "multilingual_multichoice"

logger:
log_dir: "logs/cross-tokenizer-distillation-llama1b-phi4mini"
num_val_samples_to_print: 5
wandb_enabled: true
swanlab_enabled: false
mlflow_enabled: false
tensorboard_enabled: false
monitor_gpus: true
wandb:
project: "nemo-cross-tokenizer-distillation"
name: "cross-tokenizer-llama1b-qwen8b-bs768"
gpu_monitoring:
collection_interval: 10
flush_interval: 10

cluster:
gpus_per_node: 8
num_nodes: 16
32 changes: 32 additions & 0 deletions examples/configs/dist.yaml
@@ -0,0 +1,32 @@
defaults: distillation_math.yaml
distillation:
num_prompts_per_step: 512
max_num_steps: 500
val_batch_size: 512
val_period: 20
loss_fn:
kl_type: reverse
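    # reverse KL is mode-seeking (the student concentrates on the teacher's
    # high-probability modes), unlike the mass-covering forward KL.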
checkpointing:
model_save_format: "torch_save"
keep_top_k: 3
checkpoint_dir: checkpoints/distillation-qwen3-32b-to-4b-base-long
policy:
model_name: Qwen/Qwen3-4B-Base
train_global_batch_size: 512
max_total_sequence_length: 20480
generation:
colocated:
enabled: true # Try setting to false if IPC issues persist
vllm_cfg:
tensor_parallel_size: 1
gpu_memory_utilization: 0.7
teacher:
model_name: Qwen/Qwen3-32B
max_total_sequence_length: 20480
logger:
log_dir: logs/distillation-qwen3-32b-to-4b-base-long
wandb:
project: nemo-rl
name: distillation-qwen3-32b-to-4b-base-long
cluster:
num_nodes: 1
2 changes: 1 addition & 1 deletion examples/configs/distillation_math.yaml
@@ -215,7 +215,7 @@ teacher:
dtensor_cfg:
<<: *DTENSOR_BASE
context_parallel_size: 2
- tensor_parallel_size: 4
+ tensor_parallel_size: 2

data:
max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
17 changes: 17 additions & 0 deletions examples/configs/evals/llama_math_eval.yaml
@@ -0,0 +1,17 @@
# Math evaluation for Llama 3.2 1B (base model)
# Override generation.model_name with checkpoint path at runtime:
# generation.model_name=/path/to/checkpoint/policy/weights
defaults: "eval.yaml"

generation:
model_name: "meta-llama/Llama-3.2-1B"
vllm_cfg:
max_model_len: 2048

tokenizer:
name: ${generation.model_name}
chat_template: null # base model, no chat formatting

data:
prompt_file: "examples/prompts/cot.txt"
dataset_name: "math"
21 changes: 21 additions & 0 deletions examples/configs/evals/llama_mmlu_eval.yaml
@@ -0,0 +1,21 @@
# MMLU evaluation for Llama 3.2 1B (base model)
# Override generation.model_name with checkpoint path at runtime:
# generation.model_name=/path/to/checkpoint/policy/weights
defaults: "eval.yaml"

generation:
model_name: "meta-llama/Llama-3.2-1B"
vllm_cfg:
max_model_len: 2048

tokenizer:
name: ${generation.model_name}
chat_template: null # base model, no chat formatting

data:
prompt_file: "examples/prompts/mmlu.txt"
dataset_name: "mmlu"

env:
math:
verifier_type: "multilingual_multichoice"