diff --git a/CROSS_TOKENIZER_README.md b/CROSS_TOKENIZER_README.md new file mode 100644 index 0000000000..eb9917c052 --- /dev/null +++ b/CROSS_TOKENIZER_README.md @@ -0,0 +1,429 @@ +# Cross-Tokenizer Off-Policy Distillation for NeMo RL + +This document describes the cross-tokenizer off-policy distillation feature built +on top of NeMo RL. It enables knowledge distillation between teacher and student +models that use **different tokenizers and vocabularies** (e.g., Qwen 8B teacher +distilling into a Llama 1B student). + +--- + +## Table of Contents + +1. [Base Commit](#base-commit) +2. [Overview](#overview) +3. [Architecture](#architecture) +4. [Commit History](#commit-history) +5. [New Files](#new-files) +6. [Modified Existing Files](#modified-existing-files) +7. [How to Run](#how-to-run) +8. [Configuration Reference](#configuration-reference) +9. [Design Decisions](#design-decisions) + +--- + +## Base Commit + +All changes are built on top of the **NeMo RL `v0.5.0`** release. + +| Field | Value | +|----------------|----------------------------------------------------------------------------| +| Repository | [NVIDIA-NeMo/RL](https://github.com/NVIDIA-NeMo/RL) (the `origin` remote) | +| Tag | `v0.5.0` | +| Commit hash | `6c7089300fded94abfa49bb9cbf9eb357d862461` | +| Commit message | `cp: Bump protobuf to 6.33.5 and python-multipart to 0.0.22 into r0.5.0 (#1851)` | +| Branch | `xtoken/off-policy-distillation` | + +To verify locally: + +```bash +git log --oneline 6c708930 -1 +# Expected: 6c708930 cp: Bump protobuf to 6.33.5 and python-multipart to 0.0.22 into `r0.5.0` (#1851) + +git tag --contains 6c708930 +# Expected: v0.5.0 +``` + +--- + +## Overview + +Standard NeMo RL distillation assumes the teacher and student share the same +tokenizer. This fork removes that constraint by adding two major capabilities: + +1. **Off-policy distillation** -- A training loop that uses a fixed dataset of + text (Arrow files) instead of generating responses on-policy. 
The teacher + produces logits for the fixed responses and the student aligns to them via KL + divergence. This is simpler and cheaper than on-policy distillation because + there is no student generation step or environment needed. + +2. **Cross-tokenizer support (TokenAligner)** -- When the teacher and student + use different tokenizers (e.g., Qwen's 151K-token vocabulary vs. Llama's + 128K-token vocabulary), a precomputed *projection matrix* maps student + probabilities into the teacher's vocabulary space. A dynamic-programming + *token alignment* algorithm aligns the two tokenizations of each text at + the sequence level so the KL loss is computed at comparable positions. + +Together, these enable distillation from any teacher to any student regardless +of their tokenizer, which was not previously possible in NeMo RL. + +--- + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Training Step Overview │ +├─────────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. Load batch of text from Arrow dataset │ +│ ├── Tokenize with student tokenizer → student_input_ids │ +│ └── Tokenize with teacher tokenizer → teacher_input_ids │ +│ │ +│ 2. Token alignment (TokenAligner.align) │ +│ └── DP alignment of student & teacher token sequences │ +│ → aligned_pairs: list of (s_start, s_end, t_start, t_end) │ +│ │ +│ 3. Teacher forward pass (via IPC) │ +│ ├── Teacher model produces full-vocab logits │ +│ ├── Log-softmax computed distributedly (TP-aware) │ +│ └── Stored in GPU IPC buffers (no Ray data transfer) │ +│ │ +│ 4. 
Student forward pass + loss │ +│ ├── Student model produces logits │ +│ ├── Reads teacher logits from IPC buffers │ +│ ├── CrossTokenizerDistillationLossFn computes: │ +│ │ ├── Project student probs → teacher vocab via projection matrix │ +│ │ ├── Chunk-average over aligned spans │ +│ │ └── KL divergence per chunk, masked_mean reduction │ +│ └── Backprop through student only │ +│ │ +│ 5. (Optional) Periodic MATH/MMLU evaluation via colocated vLLM │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +### Key components + +- **TokenAligner** (`nemo_rl/algorithms/x_token/tokenalign.py`): Core module + that handles vocabulary projection and sequence-level token alignment. + - *Projection matrix*: A precomputed mapping from student vocabulary to + teacher vocabulary. Each student token maps to a weighted set of teacher + tokens (e.g., top-32 most likely correspondences). + - *DP alignment*: Dynamic-programming algorithm that aligns student and + teacher token sequences allowing 1:1, 1:N, N:1, and N:M mappings. + Uses anchor-based optimization (unique n-gram matches) to avoid + quadratic cost on long sequences. + - *Token canonicalization*: Normalizes tokenizer-specific representations + (SentencePiece byte tokens, space prefixes, Unicode artifacts) so the + DP alignment can match across tokenizer families. + +- **Projection matrix generators** (`nemo_rl/algorithms/x_token/`): + Standalone scripts to precompute the projection matrix offline. + - `minimal_projection_via_multitoken.py` -- Multi-token analysis: tokenize + each student token string with the teacher tokenizer and distribute + probability mass across the resulting sub-tokens. + - `minimal_projection_generator.py` -- Embedding similarity: use LLM + embedding layers to compute cosine similarity between vocabularies. + - `reapply_exact_map.py` -- Post-processing: force exact 1:1 mappings for + tokens that are identical across both tokenizers. 
+ +- **IPC teacher logits**: Teacher logits are passed between the teacher and + student forward passes using CUDA IPC handles instead of serializing through + Ray. This avoids expensive CPU-GPU-CPU round-trips for large logit tensors. + +--- + +## Commit History + +Eight commits on top of `v0.5.0`, listed oldest to newest: + +| Hash | Message | +|------------|---------------------------------------------------------------------| +| `0658b8d2` | Add off-policy distillation with MATH/MMLU eval and IPC optimization | +| `668c37ed` | Commit before refactoring | +| `13066d63` | Simplify off-policy distillation IPC path and config | +| `f733e57c` | Working IPC TP=1 | +| `3204ac78` | Per-microbatch IPC teacher logits with TP=4 support | +| `d4de1d8f` | Clean up unused scripts and old distillation module | +| `f9fe64a5` | Add IPC/non-IPC toggle for off-policy distillation | +| `58a1bd71` | Integrate cross-tokenizer distillation (TokenAligner) into NeMo RL | + +--- + +## New Files + +### Core algorithm + +| File | Lines | Purpose | +|------|-------|---------| +| `nemo_rl/algorithms/off_policy_distillation.py` | ~1,100 | Off-policy distillation training loop. Contains `off_policy_distillation_train()` which iterates over a fixed dataset, runs teacher inference, and trains the student with KL loss. Handles checkpointing, validation, eval hooks, and cross-tokenizer data preparation. Created because the existing `distillation.py` is on-policy (generates student responses via rollout), which is unnecessary and expensive when training on a fixed text corpus. | + +### Cross-tokenizer module (`nemo_rl/algorithms/x_token/`) + +| File | Lines | Purpose | +|------|-------|---------| +| `tokenalign.py` | ~4,300 | Core `TokenAligner` class (`nn.Module`). Handles projection matrix loading/management, DP-based sequence alignment, token canonicalization, and multiple KL loss computation strategies (standard, optimized with vocab top-k, gold loss with common/uncommon vocab split). 
This is the central piece that makes cross-tokenizer distillation possible. | +| `minimal_projection_via_multitoken.py` | ~930 | Generates projection matrices via multi-token analysis. For each student token, tokenizes its string with the teacher tokenizer and distributes weight across the resulting sub-tokens with exponential decay. Preferred method for generating projection matrices. | +| `minimal_projection_generator.py` | ~570 | Generates projection matrices via embedding cosine similarity. Uses LLM first-layer embeddings or sentence transformers. Alternative to the multi-token method. | +| `reapply_exact_map.py` | ~230 | Post-processes a projection matrix to enforce perfect 1:1 mappings for tokens that are identical across both tokenizers (e.g., punctuation, digits). | +| `sort_and_cut_projection_matrix.py` | ~440 | Utility to sort projection matrix rows by weight and apply a top-k cutoff. Includes optional Sinkhorn renormalization. | +| `__init__.py` | 3 | Exports `TokenAligner`. | + +### Training entry points and configs + +| File | Purpose | +|------|---------| +| `examples/run_off_policy_distillation_arrow_with_eval.py` | Main training script. Extends `off_policy_distillation_train()` with periodic generation-based evaluation on MATH and MMLU using colocated vLLM. Handles cross-tokenizer setup when `token_aligner.enabled: true`. | +| `examples/configs/cross_tokenizer_off_policy_arrow.yaml` | Reference YAML config for Llama-3.2-1B (student) with Qwen3-8B-Base (teacher). Includes all token_aligner, loss_fn, eval, and cluster settings. | +| `submit_cross_tokenizer.sh` | SLURM submission script for the cross-tokenizer experiment. Supports chained job dependencies (`-n N` for sequential restarts). | + +### Dataset support + +| File | Purpose | +|------|---------| +| `nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py` | `ArrowTextDataset` class for loading Arrow files with a `text` column. 
Wraps each text as an assistant message for SFT-style training. Created because the existing dataset classes expected prompt-response pairs or specific dataset formats, not raw text Arrow files. | + +### Other new files + +| File | Purpose | +|------|---------| +| `examples/run_off_policy_distillation_arrow.py` | Simpler off-policy script without evaluation (same-tokenizer only). | +| `examples/run_sft_arrow_with_eval.py` | SFT training on Arrow data with MATH/MMLU eval. Used as reference for the eval integration pattern. | +| `examples/configs/llama_off_policy_arrow.yaml` | Config for same-tokenizer off-policy distillation (Llama teacher + Llama student). | + +--- + +## Modified Existing Files + +### `nemo_rl/models/policy/lm_policy.py` + +**What changed:** Added `teacher_forward()` method, extended `train()` with `is_teacher`, `teacher_logits`, and `topk_logits` parameters, and replaced hard `config["key"]` accesses with safer `config.get("key", {}).get(...)` patterns. + +**Why:** The existing `Policy` class had no concept of a teacher-only forward pass. For IPC-based distillation, the teacher needs to run a forward pass that stores logits in GPU IPC buffers without returning data through Ray. The `train()` method was extended so the same worker infrastructure can handle both teacher inference and student training in a single call. The safer config accesses were needed because the teacher policy config omits optional keys like `dynamic_batching` and `sequence_packing` that the original code assumed were always present. + +Key additions: +- `teacher_forward()` dispatches a teacher-only forward to workers, storing results in IPC buffers. +- `train()` gains `is_teacher=True` mode: skips optimizer, returns IPC handles instead of loss. +- `train()` gains `teacher_logits` parameter: when provided, each worker reads teacher logits from IPC handles for its microbatch. 
+- Config accesses like `config["dynamic_batching"]["enabled"]` changed to `config.get("dynamic_batching", {}).get("enabled", False)` to avoid `KeyError` when these sections are absent. + +### `nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` + +**What changed:** Extended `train()` to handle teacher inference mode (IPC buffer allocation, top-k logit extraction, distributed log-softmax) and student training mode (reading teacher logits from IPC handles). Disabled temperature scaling during training. Added debug logging for NaN detection. + +**Why:** The DTensor worker is where the actual model forward pass happens. To support IPC-based distillation: +- In teacher mode (`is_teacher=True`), the worker runs the model, computes distributed log-softmax across TP ranks, optionally extracts top-k logits, and stores results in pre-allocated CUDA IPC buffers. This avoids serializing large logit tensors through Ray. +- In student mode, the worker reads teacher logits from IPC handles (using `rebuild_cuda_tensor_from_ipc`) and passes them to the loss function alongside the student's own logits. +- Temperature scaling was being applied during training, which distorts the KL divergence computation. It is now skipped (`skip=True`) during training and only applied during generation/inference. + +### `nemo_rl/algorithms/loss_functions.py` + +**What changed:** Extended `DistillationLossFn.__call__()` with three code paths (IPC top-k, IPC full-logprob, standard data-dict), and added the entirely new `CrossTokenizerDistillationLossFn` class (~560 lines). + +**Why:** +- The original `DistillationLossFn` only supported a single path where teacher top-k logits were pre-computed and passed in the data dict. The IPC paths were added to avoid materializing teacher logits on CPU (they stay on GPU in IPC buffers). +- `CrossTokenizerDistillationLossFn` is new and handles the case where teacher and student have different vocabularies. 
It uses the TokenAligner's projection matrix to map student probabilities into the teacher's vocabulary space, then computes chunk-averaged KL divergence over aligned token spans. It also supports a *gold loss* variant that splits the vocabulary into common tokens (direct KL) and uncommon tokens (sorted L1 / Universal Likelihood Distillation). + +### `nemo_rl/data/datasets/response_datasets/__init__.py` + +**What changed:** Registered `ArrowTextDataset` in the dataset factory so `dataset_name: "arrow_text"` works in config files. + +**Why:** The existing factory had no support for raw text Arrow files. Adding the registration allows the off-policy training scripts to load Arrow datasets through the standard NeMo RL data pipeline. + +### `nemo_rl/algorithms/distillation.py` + +**What changed:** Minor compatibility adjustments. + +**Why:** Small fixes to ensure the existing on-policy distillation module works alongside the new off-policy code without import conflicts. + +--- + +## How to Run + +### Prerequisites + +- NeMo RL environment (container or `uv` virtual env) based on `v0.5.0` +- Access to teacher and student model weights on HuggingFace (e.g., `Qwen/Qwen3-8B-Base`, `meta-llama/Llama-3.2-1B`) +- Training data as Arrow files with a `text` column +- SLURM cluster with GPU nodes + +### Step 1: Generate the projection matrix + +The projection matrix maps student vocabulary to teacher vocabulary. 
Generate it +once offline: + +```bash +# Multi-token method (recommended) +python nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py \ + --student-model meta-llama/Llama-3.2-1B \ + --teacher-model Qwen/Qwen3-8B-Base + +# Optionally enforce exact matches for identical tokens +python nemo_rl/algorithms/x_token/reapply_exact_map.py \ + --student-model meta-llama/Llama-3.2-1B \ + --teacher-model Qwen/Qwen3-8B-Base \ + --initial-projection-path cross_tokenizer_data/transformation_counts_via_multitoken.pt + +# Optionally sort and cut to top-k per row +python nemo_rl/algorithms/x_token/sort_and_cut_projection_matrix.py \ + --input cross_tokenizer_data/transformation_counts_via_multitoken.pt \ + --top-k 32 +``` + +The output is a `.pt` file containing `{indices, likelihoods}` tensors. + +### Step 2: Configure the YAML + +Edit `examples/configs/cross_tokenizer_off_policy_arrow.yaml` or create a new +config. The key sections are: + +```yaml +token_aligner: + enabled: true + projection_matrix_path: "path/to/projection_map.pt" + use_sparse_format: false + loss_type: "KL" + vocab_topk: 8192 # Reduce teacher vocab to top-8192 for speed + max_comb_len: 4 # Max tokens in a single DP alignment chunk + +policy: + model_name: "meta-llama/Llama-3.2-1B" # Student + # ... optimizer, scheduler, dtensor config ... + +teacher: + model_name: "Qwen/Qwen3-8B-Base" # Teacher + # ... dtensor config (no optimizer needed) ... 
+ +loss_fn: + loss_type: "KL" + gold_loss: true # Common-vocab KL + uncommon-vocab L1 + xtoken_loss: true # Relaxed exact-map threshold (>=0.6) + ce_loss_scale: 0.1 # Optional next-token CE loss + dynamic_loss_scaling: true + +data: + dataset_name: "arrow_text" + arrow_files: "/path/to/data/*.arrow" + max_input_seq_length: 4096 + +distillation: + use_ipc: true # Required for cross-tokenizer + topk_logits_k: 8192 + num_prompts_per_step: 768 + max_num_steps: 80000 +``` + +### Step 3: Submit the job + +```bash +# Single run +bash submit_cross_tokenizer.sh + +# Chain 5 sequential jobs (each picks up from the last checkpoint) +bash submit_cross_tokenizer.sh -n 5 +``` + +The script submits a SLURM job that: +1. Starts a Ray cluster across all allocated nodes +2. Runs `examples/run_off_policy_distillation_arrow_with_eval.py` +3. Periodically evaluates on MATH and MMLU via colocated vLLM +4. Logs to Weights & Biases + +### Same-tokenizer mode + +If the teacher and student share the same tokenizer, set `token_aligner.enabled: false` +(or omit the `token_aligner` section). The training loop falls back to the +standard `DistillationLossFn` with top-k teacher logits. 
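For intuition, the fallback top-k KL can be written out in a few lines. The following is a standalone sketch, not the NeMo RL `DistillationLossFn` itself (`softmax` and `topk_forward_kl` are hypothetical helper names); it assumes the `zero_outside_topk: true` convention from the configs, where teacher probability mass outside the top-k is ignored rather than renormalized:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def topk_forward_kl(teacher_logits, student_logits, k):
    """Forward KL restricted to the teacher's top-k token indices.

    Mass outside the top-k is simply dropped (cf. zero_outside_topk: true),
    so a student that exactly matches the teacher incurs zero loss.
    """
    p_t = softmax(teacher_logits)
    p_s = softmax(student_logits)
    topk = sorted(range(len(p_t)), key=lambda i: -p_t[i])[:k]
    return sum(p_t[i] * (math.log(p_t[i]) - math.log(p_s[i])) for i in topk)

# Identical teacher/student logits -> exactly zero KL on the kept indices.
assert topk_forward_kl([2.0, 1.0, 0.0], [2.0, 1.0, 0.0], k=2) == 0.0
# A mismatched student incurs a positive penalty on the teacher's top tokens.
print(topk_forward_kl([2.0, 1.0, 0.0], [0.0, 1.0, 2.0], k=2))
```

Dropping (rather than renormalizing) the tail mass keeps the loss exactly zero at convergence, which is why the same-tokenizer path can train on only the top-k teacher logits stored in the data dict.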
+ +--- + +## Configuration Reference + +### `token_aligner` section + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| `enabled` | bool | `false` | Master switch for cross-tokenizer mode | +| `projection_matrix_path` | str | required | Path to `.pt` projection matrix file | +| `use_sparse_format` | bool | `true` | Load projection as sparse COO (faster for large vocabs) | +| `loss_type` | str | `"KL"` | `"KL"`, `"cross_entropy"`, or `"chunked_ce"` | +| `exact_token_match_only` | bool | `false` | Only use 1:1 aligned positions for loss | +| `temperature` | float | `1.0` | Softmax temperature for KL computation | +| `vocab_topk` | int | `8192` | Reduce teacher vocab to top-k (0 = use all) | +| `reverse_kl` | bool | `false` | Use reverse KL direction | +| `projection_matrix_multiplier` | float | `1.0` | Scaling factor for projection matrix | +| `max_comb_len` | int | `4` | Max combination length for DP alignment | +| `learnable` | bool | `false` | Make projection matrix trainable | +| `project_teacher_to_student` | bool | `false` | Project teacher to student vocab instead | + +### `loss_fn` section (cross-tokenizer) + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| `loss_type` | str | `"KL"` | Loss type | +| `temperature` | float | `1.0` | Softmax temperature | +| `vocab_topk` | int | `8192` | Teacher vocab top-k filtering | +| `exact_token_match_only` | bool | `false` | Restrict to 1:1 aligned positions | +| `reverse_kl` | bool | `false` | Reverse KL direction | +| `gold_loss` | bool | `false` | Common-vocab KL + uncommon-vocab sorted L1 | +| `xtoken_loss` | bool | `false` | Relaxed exact-map threshold (>=0.6 instead of ==1.0) | +| `ce_loss_scale` | float | `0.0` | Weight for auxiliary next-token CE loss (0 = disabled) | +| `dynamic_loss_scaling` | bool | `false` | Scale KL loss to match CE loss magnitude | + +--- + +## Design Decisions + +### Why off-policy distillation? 
+ +On-policy distillation (the existing `distillation.py`) generates student +responses, scores them with an environment, and uses the teacher to compute +logits on those responses. This requires a generation engine, an environment, +and produces different data each step. Off-policy distillation uses a *fixed* +text corpus -- the same data the teacher was trained on. This is: +- **Simpler**: No generation step, no environment, no rollout turns. +- **Cheaper**: No vLLM inference for student generation during training. +- **Deterministic**: Same data every epoch, easier to debug and reproduce. +- **Sufficient for distillation**: When the goal is to transfer the teacher's + language modeling ability (not RL-specific behavior), a fixed corpus works well. + +### Why IPC for teacher logits? + +Teacher logits for a batch of shape `[B, S, V]` (e.g., `[768, 4096, 151936]` +for Qwen 8B) are hundreds of gigabytes. Passing them through Ray would require +serializing to CPU, transferring, and deserializing back to GPU. CUDA IPC handles +allow the student worker to read the teacher's GPU memory directly without any +data movement. This is why `distillation.use_ipc: true` is required for +cross-tokenizer mode. + +### Why chunk-averaged KL? + +When teacher and student tokenize the same text differently, there is no 1:1 +correspondence between all token positions. For example, the word "unhappiness" +might be `["un", "happiness"]` in one tokenizer and `["un", "happ", "iness"]` +in another. The DP alignment finds these correspondences and groups them into +*chunks*. Within each chunk, the teacher and student distributions are averaged +over their respective token spans, renormalized, and compared via KL divergence. +This handles 1:1, 1:N, N:1, and N:M alignments uniformly. + +### Why gold loss? + +The standard projection-based path projects student probabilities into the +teacher's vocabulary space using a precomputed matrix. 
This introduces +approximation error for tokens that don't have clean 1:1 mappings. The *gold +loss* variant avoids projection entirely for tokens that have exact matches +between vocabularies (e.g., digits, punctuation, common words). For these +"common" tokens, KL is computed directly on the native log-probabilities. For +"uncommon" tokens (no exact mapping), it falls back to sorted L1 on probability +vectors (Universal Likelihood Distillation). This typically gives better gradient +signal for the majority of tokens. + +### Why per-microbatch IPC? + +Large batches are split into microbatches for gradient accumulation. Rather than +storing the entire batch of teacher logits in GPU memory (which may not fit), +each microbatch's teacher logits are stored in a separate IPC buffer. The +student reads only the current microbatch's buffer during its forward pass. +This keeps peak GPU memory proportional to the microbatch size, not the global +batch size. The implementation also supports tensor parallelism (TP > 1) where +each rank stores and reads only its local vocabulary shard. 
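As a concrete illustration of the chunk-averaged KL described above, the following standalone sketch (hypothetical helper names; not the `CrossTokenizerDistillationLossFn` implementation) computes the loss for DP-aligned spans over an already-projected shared vocabulary:

```python
import math

def forward_kl(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)

def chunk_averaged_kl(teacher_probs, student_probs, aligned_pairs):
    """Average KL over DP-aligned chunks.

    teacher_probs / student_probs: per-position distributions over a shared
    (projected) vocabulary. aligned_pairs: (s_start, s_end, t_start, t_end)
    spans from the token alignment. Within a chunk, each side's positions are
    averaged into a single distribution before the KL is taken, which treats
    1:1, 1:N, N:1, and N:M alignments uniformly.
    """
    def span_mean(dists, lo, hi):
        n = hi - lo
        return [sum(d[v] for d in dists[lo:hi]) / n for v in range(len(dists[0]))]

    kls = [forward_kl(span_mean(teacher_probs, t0, t1),
                      span_mean(student_probs, s0, s1))
           for s0, s1, t0, t1 in aligned_pairs]
    return sum(kls) / len(kls)

# Teacher tokenizes a word into two tokens, student keeps one: an N:1 chunk.
teacher = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]   # two teacher positions
student = [[0.4, 0.4, 0.2]]                    # one student position
print(chunk_averaged_kl(teacher, student, [(0, 1, 0, 2)]))
```

In the example the teacher's two positions average to the same distribution as the student's single position, so the N:1 chunk contributes (numerically) zero loss; renormalization after averaging is a no-op here because each position is already a proper distribution.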
diff --git a/eval_results.csv b/eval_results.csv new file mode 100644 index 0000000000..604f71df2f --- /dev/null +++ b/eval_results.csv @@ -0,0 +1,11 @@ +Model,top_k,MATH (%),MATH (correct),MATH (total),MMLU (%),MMLU (correct),MMLU (total),wandb +Llama-3.1-8B (teacher),--,12.56,628,5000,35.14,4935,14042, +Llama-3.2-1B (pretrained baseline),--,5.64,282,5000,23.06,3238,14042, +Distillation forward KL (50 steps),64,5.52,276,5000,23.25,3265,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1 +Distillation forward KL (1000 steps),64,5.84,292,5000,26.24,3684,14042,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/uc8nnmh1 +Delta from baseline (top_k=64),,+0.20,+10,,+3.18,+446,, +Distillation forward KL (50 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz +Distillation forward KL (1000 steps),4096,5.27,27,512,26.56,136,512,https://wandb.ai/nvidia/nemo-off-policy-distillation-eval/runs/8a7dictz +SFT (50 steps),--,4.06,203,5000,16.79,2358,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw +SFT (1000 steps),--,7.30,365,5000,22.13,3108,14042,https://wandb.ai/nvidia/nemo-sft-arrow-eval/runs/w5fiqlbw +Delta from baseline (SFT),,+1.66,+83,,-0.93,-130,, diff --git a/examples/configs/cross_tokenizer_off_policy_arrow.yaml b/examples/configs/cross_tokenizer_off_policy_arrow.yaml new file mode 100644 index 0000000000..d5fd7fa2f3 --- /dev/null +++ b/examples/configs/cross_tokenizer_off_policy_arrow.yaml @@ -0,0 +1,245 @@ +# Cross-Tokenizer Off-Policy Distillation Configuration +# Student: Llama-3.2-1B (128K vocab) <- Teacher: Qwen3-8B (151K vocab) +# +# Requires a precomputed projection matrix. 
Generate with:
+#   python nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py \
+#     --student-model meta-llama/Llama-3.2-1B \
+#     --teacher-model Qwen/Qwen3-8B-Base
+#
+# Then optionally enforce exact matches:
+#   python nemo_rl/algorithms/x_token/reapply_exact_map.py \
+#     --student-model meta-llama/Llama-3.2-1B \
+#     --teacher-model Qwen/Qwen3-8B-Base \
+#     --initial-projection-path cross_tokenizer_data/transformation_counts_via_multitoken.pt
+
+token_aligner:
+  enabled: true
+  projection_matrix_path: "cross_tokenizer_data/projection_map_Llama-3.2_to_Phi-4-mini-instruct_multitoken_top_32_double_special.pt"
+  use_sparse_format: false
+  loss_type: "KL"
+  exact_token_match_only: false
+  temperature: 1.0
+  vocab_topk: 8192
+  reverse_kl: false
+  projection_matrix_multiplier: 1.0
+  max_comb_len: 4
+  learnable: false
+  project_teacher_to_student: false
+  use_char_offset: false
+  use_align_fast: true
+  use_cuda_dp: false
+  dp_chunk_size: 128
+
+distillation:
+  num_prompts_per_step: 768
+  num_generations_per_prompt: 1
+  max_num_steps: 80000
+  max_num_epochs: 1
+  val_period: 1000
+  val_at_start: false
+  max_val_samples: 128
+  val_batch_size: 64
+  topk_logits_k: 8192
+  use_ipc: true
+  loss_on_all_tokens: true
+  seed: 42
+  # Number of CPU processes for cross-tokenizer decode/encode/align.
+  # Heuristic currently used: total GPUs / 2 (16 nodes * 8 GPUs = 128 -> 64 workers).
+  # Note: this is workload-dependent; tune per batch size and cluster shape.
+ cross_tokenizer_num_workers: 64 + +loss_fn: + loss_type: "KL" + temperature: 1.0 + vocab_topk: 8192 + exact_token_match_only: false + reverse_kl: false + project_teacher_to_student: false + gold_loss: true + xtoken_loss: true + ce_loss_scale: 0.1 + dynamic_loss_scaling: true + +checkpointing: + enabled: true + checkpoint_dir: "checkpoints/cross-tokenizer-distillation-llama1b-qwen8b" + metric_name: "train:loss" + higher_is_better: false + keep_top_k: 3 + save_period: 10 + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: "meta-llama/Llama-3.2-1B" + chat_template: null + train_global_batch_size: 768 + train_micro_batch_size: 1 + max_total_sequence_length: 4096 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 4000 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 76000 + eta_min: 0.0 + - milestones: [4000] + + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 2048 + enforce_eager: false + use_deep_gemm: false + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + 
+teacher: + model_name: "microsoft/Phi-4-mini-instruct" + tokenizer: + name: "microsoft/Phi-4-mini-instruct" + chat_template: null + precision: "bfloat16" + train_global_batch_size: 768 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + max_grad_norm: 1.0 + logprob_chunk_size: null + offload_optimizer_for_logprob: false + + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + + generation: null + +data: + dataset_name: "arrow_text" + arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-00[0-4][0-9][0-9]-of-02476.arrow" + # arrow_files="/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-000[0-5][0-9]-of-02476.arrow" + prompt_file: null + max_input_seq_length: 4096 + shuffle: true + +eval: + val_period: 1000 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + mmlu_5shot: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + num_few_shot: 5 + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + +logger: + 
log_dir: "logs/cross-tokenizer-distillation-llama1b-qwen8b" + num_val_samples_to_print: 5 + wandb_enabled: true + swanlab_enabled: false + mlflow_enabled: false + tensorboard_enabled: false + monitor_gpus: true + wandb: + project: "nemo-cross-tokenizer-distillation" + name: "cross-tokenizer-llama1b-qwen8b-bs768" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 16 diff --git a/examples/configs/dist.yaml b/examples/configs/dist.yaml new file mode 100644 index 0000000000..df541b72db --- /dev/null +++ b/examples/configs/dist.yaml @@ -0,0 +1,32 @@ +defaults: distillation_math.yaml +distillation: + num_prompts_per_step: 512 + max_num_steps: 500 + val_batch_size: 512 + val_period: 20 +loss_fn: + kl_type: reverse +checkpointing: + model_save_format: "torch_save" + keep_top_k: 3 + checkpoint_dir: checkpoints/distillation-qwen3-32b-to-4b-base-long +policy: + model_name: Qwen/Qwen3-4B-Base + train_global_batch_size: 512 + max_total_sequence_length: 20480 + generation: + colocated: + enabled: true # Try setting to false if IPC issues persist + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.7 +teacher: + model_name: Qwen/Qwen3-32B + max_total_sequence_length: 20480 +logger: + log_dir: logs/distillation-qwen3-32b-to-4b-base-long + wandb: + project: nemo-rl + name: distillation-qwen3-32b-to-4b-base-long +cluster: + num_nodes: 1 \ No newline at end of file diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index 9d7168a182..4612a436d8 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -215,7 +215,7 @@ teacher: dtensor_cfg: <<: *DTENSOR_BASE context_parallel_size: 2 - tensor_parallel_size: 4 + tensor_parallel_size: 2 data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len diff --git a/examples/configs/evals/llama_math_eval.yaml 
b/examples/configs/evals/llama_math_eval.yaml new file mode 100644 index 0000000000..70456cf6fc --- /dev/null +++ b/examples/configs/evals/llama_math_eval.yaml @@ -0,0 +1,17 @@ +# Math evaluation for Llama 3.2 1B (base model) +# Override generation.model_name with checkpoint path at runtime: +# generation.model_name=/path/to/checkpoint/policy/weights +defaults: "eval.yaml" + +generation: + model_name: "meta-llama/Llama-3.2-1B" + vllm_cfg: + max_model_len: 2048 + +tokenizer: + name: ${generation.model_name} + chat_template: null # base model, no chat formatting + +data: + prompt_file: "examples/prompts/cot.txt" + dataset_name: "math" diff --git a/examples/configs/evals/llama_mmlu_eval.yaml b/examples/configs/evals/llama_mmlu_eval.yaml new file mode 100644 index 0000000000..8eb64dd4c4 --- /dev/null +++ b/examples/configs/evals/llama_mmlu_eval.yaml @@ -0,0 +1,21 @@ +# MMLU evaluation for Llama 3.2 1B (base model) +# Override generation.model_name with checkpoint path at runtime: +# generation.model_name=/path/to/checkpoint/policy/weights +defaults: "eval.yaml" + +generation: + model_name: "meta-llama/Llama-3.2-1B" + vllm_cfg: + max_model_len: 2048 + +tokenizer: + name: ${generation.model_name} + chat_template: null # base model, no chat formatting + +data: + prompt_file: "examples/prompts/mmlu.txt" + dataset_name: "mmlu" + +env: + math: + verifier_type: "multilingual_multichoice" diff --git a/examples/configs/llama_off_policy_arrow.yaml b/examples/configs/llama_off_policy_arrow.yaml new file mode 100644 index 0000000000..aa9f23c4d2 --- /dev/null +++ b/examples/configs/llama_off_policy_arrow.yaml @@ -0,0 +1,202 @@ +# Unified Off-Policy Distillation Configuration (Llama 8B -> 1B) + +distillation: + num_prompts_per_step: 768 + num_generations_per_prompt: 1 + max_num_steps: 10000 + max_num_epochs: 20 + val_period: 0 # Handled in eval block instead + val_at_start: false + max_val_samples: 128 + val_batch_size: 64 + topk_logits_k: 8192 + use_ipc: true + loss_on_all_tokens: 
true + seed: 42 + +loss_fn: + kl_type: "forward" + zero_outside_topk: true + +checkpointing: + enabled: true + checkpoint_dir: "checkpoints/distillation-forward-kl-cosine-topk8192-16node-Llama-3.2-1B-10000steps" + metric_name: "train:loss" + higher_is_better: false + keep_top_k: 3 + save_period: 10 + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: "meta-llama/Llama-3.2-1B" + chat_template: null + train_global_batch_size: 768 + train_micro_batch_size: 1 + max_total_sequence_length: 4096 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 9950 + eta_min: 0.0 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 2048 + enforce_eager: false + use_deep_gemm: false + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + +teacher: + model_name: "meta-llama/Llama-3.1-8B" + tokenizer: + name: "meta-llama/Llama-3.1-8B" + chat_template: null + precision: "bfloat16" + train_global_batch_size: 768 + 
train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 4096 + max_grad_norm: 1.0 + logprob_chunk_size: null + offload_optimizer_for_logprob: false + + dtensor_cfg: + enabled: true + _v2: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 4096 + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: false + fused: false + + generation: null + +data: + dataset_name: "arrow_text" + arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-000[0-5][0-9]-of-02476.arrow" + prompt_file: null + max_input_seq_length: 4096 + shuffle: true + +eval: + val_period: 50 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + mmlu_5shot: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + num_few_shot: 5 + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" + +logger: + log_dir: "logs/off-policy-distillation-arrow-llama-topk8192-16node" + num_val_samples_to_print: 5 + wandb_enabled: true + swanlab_enabled: false + mlflow_enabled: false + tensorboard_enabled: false + monitor_gpus: true + wandb: + project: "nemo-off-policy-distillation-eval" + name: "distill-arrow-llama1b-8b-bs768-lr5e5-cosine-topk8192-16node-10000steps" + 
gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 16 diff --git a/examples/configs/llama_off_policy_arrow_4node.yaml b/examples/configs/llama_off_policy_arrow_4node.yaml new file mode 100644 index 0000000000..31f947170a --- /dev/null +++ b/examples/configs/llama_off_policy_arrow_4node.yaml @@ -0,0 +1,104 @@ +defaults: distillation_math.yaml +distillation: + num_prompts_per_step: 768 + max_num_steps: 1000 + max_val_samples: 128 + val_period: 0 + loss_on_all_tokens: true + topk_logits_k: 8192 +loss_fn: + kl_type: "forward" +data: + dataset_name: "arrow_text" + # Full dataset: ".../*.arrow" (2118 files, ~296M samples — too slow to preprocess) + # Using 10 files (~1.4M samples, plenty for 1000 steps x 768 batch = 768K samples) + arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-0000[0-9]-of-02476.arrow" + prompt_file: null +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: ${..model_name} + chat_template: null + train_global_batch_size: 768 + max_total_sequence_length: 4096 + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 2048 + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: False + fused: False + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 50 + 
- name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 950 + eta_min: 0.0 + - milestones: [50] +teacher: + model_name: "meta-llama/Llama-3.1-8B" + tokenizer: + name: ${..model_name} + chat_template: null +eval: + val_period: 50 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" +checkpointing: + checkpoint_dir: "checkpoints/distillation-forward-kl-cosine-topk8192-4node-${policy.model_name}" + metric_name: "train:loss" + higher_is_better: false +logger: + log_dir: logs/off-policy-distillation-arrow-llama-topk8192-4node + wandb: + project: nemo-off-policy-distillation-eval + name: distill-arrow-llama1b-8b-bs768-lr5e5-cosine-topk8192-4node +cluster: + num_nodes: 4 diff --git a/examples/configs/llama_off_policy_arrow_old.yaml b/examples/configs/llama_off_policy_arrow_old.yaml new file mode 100644 index 0000000000..6e9a2f9854 --- /dev/null +++ b/examples/configs/llama_off_policy_arrow_old.yaml @@ -0,0 +1,106 @@ +defaults: distillation_math.yaml +distillation: + num_prompts_per_step: 768 + max_num_steps: 1000 + max_val_samples: 128 + val_period: 0 + loss_on_all_tokens: true + topk_logits_k: 8192 +loss_fn: + kl_type: "forward" +data: + dataset_name: "arrow_text" + # arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/*.arrow" + # arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-00000-of-02476.arrow" + # Full dataset: ".../*.arrow" (2118 files, ~296M samples — too slow to preprocess) + # Using 10 files (~1.4M samples, plenty for 50 steps x 512 batch = 25K samples) + arrow_files: 
"/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-0000[0-9]-of-02476.arrow" + prompt_file: null +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: ${..model_name} + chat_template: null + train_global_batch_size: 768 + max_total_sequence_length: 4096 + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 2048 + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + foreach: False + fused: False + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 950 + eta_min: 0.0 + - milestones: [50] +teacher: + model_name: "meta-llama/Llama-3.1-8B" + tokenizer: + name: ${..model_name} + chat_template: null +eval: + val_period: 50 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" +checkpointing: + checkpoint_dir: "checkpoints/distillation-forward-kl-cosine-topk8192-16node-${policy.model_name}" + metric_name: "train:loss" + higher_is_better: false +logger: + log_dir: 
logs/off-policy-distillation-arrow-llama-topk8192-16node + wandb: + project: nemo-off-policy-distillation-eval + name: distill-arrow-llama1b-8b-bs768-lr5e5-cosine-topk8192-16node +cluster: + num_nodes: 8 diff --git a/examples/configs/llama_sft_arrow.yaml b/examples/configs/llama_sft_arrow.yaml new file mode 100644 index 0000000000..f09ecf29a3 --- /dev/null +++ b/examples/configs/llama_sft_arrow.yaml @@ -0,0 +1,101 @@ +defaults: sft.yaml +sft: + max_num_steps: 1000 + max_num_epochs: 10 + val_period: 0 # loss-based validation disabled; use eval section for generation-based eval + val_batches: 8 + val_global_batch_size: 768 + val_micro_batch_size: 1 + val_at_start: false + seed: 42 +data: + dataset_name: "arrow_text" + # Full dataset: ".../*.arrow" (2118 files, ~296M samples — too slow to preprocess) + # Using 10 files (~1.4M samples, plenty for 50 steps x 512 batch = 25K samples) + arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-0000[0-9]-of-02476.arrow" + prompt_file: null +policy: + model_name: "meta-llama/Llama-3.2-1B" + tokenizer: + name: ${policy.model_name} + chat_template: null # passthrough: no chat formatting (required for base models) + train_global_batch_size: 768 + max_total_sequence_length: 4096 + generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + kv_cache_dtype: "auto" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: 2048 + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + distributed_executor_backend: null + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-5 + weight_decay: 0.1 + betas: [0.9, 0.98] + 
eps: 1e-5 + foreach: False + fused: False + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.02 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 950 + eta_min: 0.0 + - milestones: [50] +eval: + val_period: 50 + val_at_start: false + max_val_samples: 512 + val_batch_size: 64 + max_rollout_turns: 1 + benchmarks: + math: + dataset_name: "math" + prompt_file: "examples/prompts/cot.txt" + env: + num_workers: 8 + mmlu: + dataset_name: "mmlu" + prompt_file: "examples/prompts/mmlu.txt" + env: + num_workers: 8 + verifier_type: "multilingual_multichoice" +checkpointing: + checkpoint_dir: "results/sft-arrow-eval" + metric_name: "val:math_accuracy" + higher_is_better: true +logger: + log_dir: logs/sft-arrow-llama-eval + num_val_samples_to_print: 3 + wandb: + project: nemo-sft-arrow-eval + name: llama-1b-sft-arrow-math-mmlu +cluster: + gpus_per_node: 8 + num_nodes: 4 diff --git a/examples/configs/off_policy_distillation.yaml b/examples/configs/off_policy_distillation.yaml new file mode 100644 index 0000000000..ee62ccf2f3 --- /dev/null +++ b/examples/configs/off_policy_distillation.yaml @@ -0,0 +1,209 @@ +# Off-Policy Distillation Configuration +# +# This config is for off-policy distillation where: +# - A fixed dataset of prompt-response pairs is used (no student generation) +# - Teacher provides logits for the fixed responses +# - Student aligns with teacher using KL divergence loss +# +# Reference: https://github.com/NVIDIA-NeMo/RL/discussions/1445 +# +# Usage: +# python run_off_policy_distillation.py --config configs/off_policy_distillation.yaml +# +# For arrow dataset: +# python run_off_policy_distillation.py --config configs/off_policy_distillation.yaml \ +# data.arrow_files="/path/to/your/*.arrow" + +distillation: + num_prompts_per_step: 64 # Batch size + max_num_steps: 1000 + max_num_epochs: 10 + topk_logits_k: 64 # Top-k logits for sparse KL loss (saves memory) + seed: 42 + 
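The `topk_logits_k` knob above controls how much of the teacher distribution is kept: only the teacher's top-k logits are stored, and the KL term is accumulated over that sparse support. As a rough illustration of the forward-KL variant only, here is a plain-Python sketch with a hypothetical `forward_kl_topk` helper (the repo's actual loss also handles `reverse`/`mixed` KL and `zero_outside_topk`, which this sketch omits):

```python
import math

def log_softmax(xs):
    # Numerically stable log-softmax over a plain list of logits.
    m = max(xs)
    z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - z for x in xs]

def forward_kl_topk(teacher_logits, student_logits, k):
    """Forward KL(teacher || student) restricted to the teacher's top-k
    logits -- the only entries shipped to the student when topk_logits_k
    is set. Hypothetical helper for illustration, not the repo's loss."""
    t_logp = log_softmax(teacher_logits)
    s_logp = log_softmax(student_logits)
    # Indices of the teacher's k largest logits.
    topk = sorted(range(len(teacher_logits)), key=lambda i: -teacher_logits[i])[:k]
    return sum(math.exp(t_logp[i]) * (t_logp[i] - s_logp[i]) for i in topk)

# Disagreeing distributions give a positive sparse-KL value;
# identical distributions give (numerically) zero.
loss = forward_kl_topk([3.0, 1.0, 0.0], [0.0, 1.0, 3.0], k=2)
```

With `k` equal to the full vocabulary size this reduces to the exact forward KL; smaller `k` trades a small bias for much less memory per token.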
+loss_fn: + kl_type: "mixed" # Options: "forward", "reverse", "mixed" + mixed_kl_weight: 0.5 # Weight for forward KL when kl_type="mixed" + zero_outside_topk: false # Zero teacher logits outside top-k for forward KL + +checkpointing: + enabled: true + checkpoint_dir: "checkpoints/off-policy-distillation-${policy.model_name}" + metric_name: null # No validation in off-policy mode + higher_is_better: true + keep_top_k: 3 + save_period: 50 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: &POLICY_BASE + model_name: "Qwen/Qwen3-1.7B-Base" + tokenizer: + name: ${..model_name} + chat_template_kwargs: null + train_global_batch_size: 64 + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 8192 + precision: "bfloat16" + logprob_chunk_size: null + + offload_optimizer_for_logprob: false + + dtensor_cfg: &DTENSOR_BASE + enabled: true + _v2: true + cpu_offload: False + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 2 + context_parallel_size: 2 + custom_parallel_plan: null + + dynamic_batching: + enabled: true + train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}} + logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}} + logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + max_grad_norm: 1.0 + make_sequence_length_divisible_by: ${mul:${mul:${.dtensor_cfg.tensor_parallel_size}, ${.dtensor_cfg.context_parallel_size}}, 2} + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 2.0e-5 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + foreach: False + fused: False + + megatron_cfg: &MEGATRON_BASE + enabled: false + empty_unused_memory_level: 0 + 
activation_checkpointing: false + converter_type: "Qwen3ForCausalLM" + tensor_model_parallel_size: 2 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 2 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 2 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" + moe_router_bias_update_rate: 0.0 + moe_permute_fusion: false + apply_rope_fusion: True + bias_activation_fusion: True + defer_fp32_logits: False + moe_per_layer_logging: False + + optimizer: + optimizer: "adam" + lr: 2.00001e-5 + min_lr: 2.0e-5 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + sgd_momentum: 0.9 + use_distributed_optimizer: true + use_precision_aware_optimizer: true + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 1000 + lr_warmup_iters: 10 + lr_warmup_init: 2.0e-6 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 10 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [10] + + # No generation config needed for off-policy distillation + generation: null + +teacher: + <<: *POLICY_BASE + model_name: "Qwen/Qwen3-4B" + dtensor_cfg: + <<: *DTENSOR_BASE + context_parallel_size: 2 + 
tensor_parallel_size: 4 + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + prompt_file: null # Optional: path to prompt template + system_prompt_file: null # Optional: path to system prompt + shuffle: true + + # Arrow files dataset + arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/*.arrow" + + # Alternative options (comment out arrow_files and uncomment one of these): + # train_data_path: "path/to/train.jsonl" # JSON/JSONL file + # hf_dataset: + # name: "nvidia/AceReason-1.1-SFT" + # subset: null + # split: "train" + +logger: + log_dir: "logs/off-policy-distillation" + num_val_samples_to_print: 0 # No validation in off-policy mode + wandb_enabled: true + tensorboard_enabled: true + mlflow_enabled: false + swanlab_enabled: false + monitor_gpus: true + wandb: + project: "nemo-off-policy-distillation" + name: "off-policy-${teacher.model_name}-to-${policy.model_name}" + tensorboard: + log_dir: "tb_logs-off-policy-distillation" + mlflow: + experiment_name: "off-policy-distillation" + run_name: "off-policy-distillation" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/configs/off_policy_distillation_arrow.yaml b/examples/configs/off_policy_distillation_arrow.yaml new file mode 100644 index 0000000000..3faf5dcf6c --- /dev/null +++ b/examples/configs/off_policy_distillation_arrow.yaml @@ -0,0 +1,25 @@ +defaults: distillation_math.yaml +distillation: + num_prompts_per_step: 512 + max_num_steps: 500 + loss_on_all_tokens: true # arrow text has single assistant message per row; explicit loss on all tokens +data: + dataset_name: "arrow_text" + # arrow_files: "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/*.arrow" + arrow_files: 
"/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-00000-of-02476.arrow" + prompt_file: null +checkpointing: + metric_name: "val:val_loss" # off-policy has no accuracy, use val_loss instead + higher_is_better: false # lower val_loss is better +policy: + train_global_batch_size: 512 + generation: null # No student generation for off-policy +teacher: + model_name: "Qwen/Qwen3-32B" +logger: + log_dir: logs/off-policy-distillation-arrow + wandb: + project: nemo-off-policy-distillation + name: off-policy-arrow +cluster: + num_nodes: 1 diff --git a/examples/converters/consolidate_checkpoint.py b/examples/converters/consolidate_checkpoint.py new file mode 100644 index 0000000000..a1a1e3349f --- /dev/null +++ b/examples/converters/consolidate_checkpoint.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""Consolidate NeMo-RL sharded safetensors checkpoint into standard HuggingFace format. + +This script reads the FSDP-sharded safetensors files saved by NeMo-RL's +AutomodelCheckpointManager and consolidates them into a single HuggingFace- +compatible directory that can be loaded by vLLM, transformers, etc. 
+ +Usage: + python consolidate_checkpoint.py \ + --input /path/to/step_50/policy/weights \ + --output /path/to/hf_checkpoint \ + --model-name meta-llama/Llama-3.2-1B +""" + +import argparse +import json +import logging +import os +import shutil +import sys + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(REPO_ROOT) +# nemo_automodel is a workspace member, add it to the path explicitly +sys.path.append(os.path.join(REPO_ROOT, "3rdparty", "Automodel-workspace", "Automodel")) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Consolidate NeMo-RL sharded checkpoint to HuggingFace format" + ) + parser.add_argument( + "--input", + type=str, + required=True, + help="Path to policy/weights directory (contains model/ subdirectory with sharded safetensors)", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output directory for HuggingFace-format checkpoint", + ) + parser.add_argument( + "--model-name", + type=str, + required=True, + help="Original HuggingFace model name (e.g., meta-llama/Llama-3.2-1B) for tokenizer", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output directory if it exists", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + input_model_dir = os.path.join(args.input, "model") + hf_metadata_dir = os.path.join(input_model_dir, ".hf_metadata") + + if not os.path.isdir(input_model_dir): + logger.error(f"Input model directory not found: {input_model_dir}") + sys.exit(1) + + if not os.path.isdir(hf_metadata_dir): + logger.error(f"HF metadata directory not found: {hf_metadata_dir}") + sys.exit(1) + + if os.path.exists(args.output) and not args.overwrite: + logger.error( + f"Output directory already exists: {args.output}. Use --overwrite to replace." 
+ ) + sys.exit(1) + + os.makedirs(args.output, exist_ok=True) + + # Step 1: Read the fqn-to-file-index mapping + mapping_path = os.path.join(hf_metadata_dir, "fqn_to_file_index_mapping.json") + with open(mapping_path, "r") as f: + fqn_to_index_mapping = json.load(f) + logger.info( + f"Loaded fqn_to_file_index_mapping with {len(fqn_to_index_mapping)} parameters" + ) + + # Step 2: Consolidate sharded safetensors into standard HuggingFace format + from nemo_automodel.components.checkpoint._backports.consolidate_hf_safetensors import ( + consolidate_safetensors_files, + ) + + logger.info(f"Consolidating shards from {input_model_dir} -> {args.output}") + consolidate_safetensors_files( + input_dir=input_model_dir, + output_dir=args.output, + fqn_to_index_mapping=fqn_to_index_mapping, + num_threads=4, + ) + logger.info("Consolidation complete") + + # Step 3: Copy config.json and generation_config.json from .hf_metadata + for config_file in ["config.json", "generation_config.json"]: + src = os.path.join(hf_metadata_dir, config_file) + dst = os.path.join(args.output, config_file) + if os.path.exists(src): + shutil.copy2(src, dst) + logger.info(f"Copied {config_file}") + + # Step 4: Save tokenizer from the original HuggingFace model + logger.info(f"Saving tokenizer from {args.model_name}") + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) + tokenizer.save_pretrained(args.output) + logger.info(f"Tokenizer saved") + + # Verify output + output_files = os.listdir(args.output) + logger.info(f"Output directory contents: {sorted(output_files)}") + logger.info(f"Done! 
HuggingFace checkpoint saved to: {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/dummy/data_inspect.py b/examples/dummy/data_inspect.py
new file mode 100644
index 0000000000..e72e14353e
--- /dev/null
+++ b/examples/dummy/data_inspect.py
@@ -0,0 +1,111 @@
+import os
+import random
+import re
+from datasets import Dataset  # using Hugging Face datasets library
+
+# --- CONFIGURATION ---
+# The path you found earlier
+ARROW_FILE_PATH = "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-01090-of-02476.arrow"
+NUM_SAMPLES = 20
+
+def parse_and_clean(text):
+    """
+    Tries to extract (input, output) pairs. Returns a dict or None.
+    """
+    if not text or len(text) < 50: return None
+
+    # STRATEGY 1: Explicit Chat Format
+    if "User" in text and "Assistant" in text:
+        try:
+            parts = text.split("Assistant")
+            return {
+                "type": "Chat (Clean)",
+                "input": parts[0].replace("User", "").strip(),
+                "output": parts[1].strip()
+            }
+        except Exception: pass
+
+    # STRATEGY 2: Multi-turn Dialogue (Participant/Student/Teacher)
+    # Regex for lines starting with "**Name:**"
+    speaker_pattern = re.compile(r'\n\*\*(.*?):\*\*')
+    matches = list(speaker_pattern.finditer(text))
+
+    if len(matches) >= 2:
+        last_turn_start = matches[-1].start()
+        return {
+            "type": "Dialogue (Multi-turn)",
+            "input": text[:last_turn_start].strip(),
+            "output": text[last_turn_start:].strip()
+        }
+
+    # STRATEGY 3: START/END Blocks (Math/Reasoning)
+    if "START:" in text and "END:" in text:
+        try:
+            content = text.split("START:", 1)[1].split("END:", 1)[0].strip()
+
+            if "Question:" in content or "Problem:" in content:
+                if "Answer:" in content:
+                    parts = content.split("Answer:", 1)
+                    return {"type": "Math/Reasoning", "input": parts[0].strip(), "output": "Answer: " + parts[1].strip()}
+                elif "Solution:" in content:
+                    parts = content.split("Solution:", 1)
+                    return {"type": "Math/Reasoning", "input": parts[0].strip(), "output": 
"Solution: " + parts[1].strip()} + + return {"type": "Raw Text (Matched START/END but Rejected)", "input": "SKIPPED", "output": "SKIPPED"} + except: pass + + return None + +def main(): + if not os.path.exists(ARROW_FILE_PATH): + print(f"ERROR: File not found at {ARROW_FILE_PATH}") + return + + print(f"--- Inspecting {ARROW_FILE_PATH} ---") + + # Load dataset using Hugging Face Library + try: + # Try loading specific file + ds = Dataset.from_file(ARROW_FILE_PATH) + # If the arrow file has no metadata, sometimes we need to just select the column + all_text = ds["text"] + except Exception as e: + print(f"Error loading file with HF datasets: {e}") + return + + # Randomly sample + samples = random.sample(all_text, min(NUM_SAMPLES, len(all_text))) + + print(f"Total Rows in File: {len(all_text)}") + print(f"Inspecting {len(samples)} random samples...\n") + print("="*60) + + stats = {"kept": 0, "skipped": 0} + + for i, text in enumerate(samples): + result = parse_and_clean(text) + + print(f"\n[Sample {i+1}] Length: {len(text)}") + + if result: + if result['input'] == "SKIPPED": + print(f"Status: SKIPPED (Raw Text)") + stats["skipped"] += 1 + # print(f"Preview: {text[:100].replace(chr(10), ' ')}...") + else: + print(f"Status: KEPT ({result['type']})") + stats["kept"] += 1 + print(f"INPUT PREVIEW: {result['input'][:100].replace(chr(10), ' ')}...") + print(f"OUTPUT PREVIEW: {result['output'][:100].replace(chr(10), ' ')}...") + else: + print(f"Status: SKIPPED (Unknown Format)") + stats["skipped"] += 1 + print(f"Preview: {text[:100].replace(chr(10), ' ')}...") + + print("-" * 30) + + print("\n" + "="*60) + print(f"SUMMARY: Kept {stats['kept']} / Skipped {stats['skipped']}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dummy/data_inspect_v2.py b/examples/dummy/data_inspect_v2.py new file mode 100644 index 0000000000..f74174b57c --- /dev/null +++ b/examples/dummy/data_inspect_v2.py @@ -0,0 +1,92 @@ +import os +import random +import re 
+from datasets import Dataset + +# --- CONFIGURATION --- +ARROW_FILE_PATH = "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/data-01090-of-02476.arrow" +NUM_SAMPLES = 20 + +def parse_and_clean(text): + if not text or len(text) < 50: return None + + # --- 1. Standard Chat (High Quality) --- + if "User" in text and "Assistant" in text: + parts = text.split("Assistant") + return {"type": "Chat (NeMo)", "input": parts[0].replace("User", "").strip(), "output": parts[1].strip()} + + # --- 2. Input/Output Format (Alpaca Style) --- + # Case insensitive check for "input:" at start + if text.lower().startswith("input:"): + # Look for "output:" or "response:" + if "output:" in text.lower(): + parts = re.split(r'output:', text, flags=re.IGNORECASE, maxsplit=1) + return {"type": "Input/Output", "input": parts[0].replace("input:", "").strip(), "output": parts[1].strip()} + + return {"type": "PARTIAL (Found input, no output)", "input": "DEBUG", "output": "DEBUG"} + + # --- 3. Question/Answer Format --- + if "Question:" in text: + # Check for Answer or Solution + if "Answer:" in text: + parts = text.split("Answer:", 1) + return {"type": "Q&A (Answer)", "input": parts[0].strip(), "output": parts[1].strip()} + elif "Solution:" in text: + parts = text.split("Solution:", 1) + return {"type": "Q&A (Solution)", "input": parts[0].strip(), "output": parts[1].strip()} + + # --- 4. Multi-turn Dialogue (Speaker names) --- + speaker_pattern = re.compile(r'\n\*\*(.*?):\*\*') + matches = list(speaker_pattern.finditer(text)) + if len(matches) >= 2: + last_turn_start = matches[-1].start() + return {"type": "Dialogue (Multi-turn)", "input": text[:last_turn_start].strip(), "output": text[last_turn_start:].strip()} + + # --- 5. 
Raw Instruction Heuristic (Coding/Math) --- + # If it starts with an instruction verb but has no clear tags + start_words = ["Given", "Write", "Create", "Implement", "Calculate", "Imagine"] + if any(text.startswith(w) for w in start_words): + # Heuristic: If there is a code block or "Solution:", split there + if "Solution:" in text: + parts = text.split("Solution:", 1) + return {"type": "Implicit Instruction", "input": parts[0].strip(), "output": parts[1].strip()} + + return None + +def main(): + print(f"--- Deep Mining {ARROW_FILE_PATH} ---") + try: + ds = Dataset.from_file(ARROW_FILE_PATH) + all_text = ds["text"] + except Exception as e: + print(f"Error: {e}") + return + + samples = random.sample(list(all_text), min(NUM_SAMPLES, len(all_text))) + + stats = {"kept": 0, "skipped": 0, "partial": 0} + + for i, text in enumerate(samples): + result = parse_and_clean(text) + print(f"\n[Sample {i+1}]") + + if result: + if result['type'].startswith("PARTIAL"): + print(f"\033[93m{result['type']}\033[0m") # Yellow + print(f"Full Text Preview: {text[:150].replace(chr(10), ' ')}...") + stats["partial"] += 1 + else: + print(f"\033[92mKEPT ({result['type']})\033[0m") # Green + print(f"IN: {result['input'][:80].replace(chr(10), ' ')}...") + print(f"OUT: {result['output'][:80].replace(chr(10), ' ')}...") + stats["kept"] += 1 + else: + print(f"\033[91mSKIPPED (Raw/Article)\033[0m") # Red + print(f"Preview: {text[:100].replace(chr(10), ' ')}...") + stats["skipped"] += 1 + + print("\n" + "="*40) + print(f"SUMMARY: Kept {stats['kept']} | Partial {stats['partial']} | Skipped {stats['skipped']}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/dummy/inspect_random_shards.py b/examples/dummy/inspect_random_shards.py new file mode 100644 index 0000000000..58c28908af --- /dev/null +++ b/examples/dummy/inspect_random_shards.py @@ -0,0 +1,89 @@ +import os +import random +import glob +import re +from datasets import Dataset + +# --- CONFIGURATION --- 
+# The directory containing all your arrow files +DATA_DIR = "/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/" +FILES_TO_INSPECT = 3 +SAMPLES_PER_FILE = 5 + +def parse_and_clean(text): + if not text or len(text) < 50: return None + + # --- 1. Standard Chat () --- + if "User" in text and "Assistant" in text: + try: + parts = text.split("Assistant") + return {"type": "Chat (NeMo)", "input": parts[0].replace("User", "").strip(), "output": parts[1].strip()} + except: pass + + # --- 2. Input/Output (Alpaca Style) --- + if text.lower().startswith("input:"): + if "output:" in text.lower(): + parts = re.split(r'output:', text, flags=re.IGNORECASE, maxsplit=1) + return {"type": "Input/Output", "input": parts[0].replace("input:", "").strip(), "output": parts[1].strip()} + + # --- 3. Question/Answer & Problem/Solution --- + if "Question:" in text or "Problem:" in text or text.startswith("Solve "): + if "Answer:" in text: + parts = text.split("Answer:", 1) + return {"type": "Q&A (Answer)", "input": parts[0].strip(), "output": parts[1].strip()} + elif "Solution:" in text: + parts = text.split("Solution:", 1) + return {"type": "Q&A (Solution)", "input": parts[0].strip(), "output": parts[1].strip()} + + # --- 4. Multi-turn Dialogue --- + speaker_pattern = re.compile(r'\n\*\*(.*?):\*\*') + matches = list(speaker_pattern.finditer(text)) + if len(matches) >= 2: + last_turn_start = matches[-1].start() + return {"type": "Dialogue (Multi-turn)", "input": text[:last_turn_start].strip(), "output": text[last_turn_start:].strip()} + + return None + +def main(): + print(f"--- Scanning Directory: {DATA_DIR} ---") + + # Get all arrow files + all_files = glob.glob(os.path.join(DATA_DIR, "*.arrow")) + print(f"Found {len(all_files)} files total.") + + if len(all_files) == 0: + print("Error: No files found. 
Check path.") + return + + # Pick random files + selected_files = random.sample(all_files, min(FILES_TO_INSPECT, len(all_files))) + + for file_path in selected_files: + print("\n" + "="*80) + print(f"INSPECTING FILE: {os.path.basename(file_path)}") + print("="*80) + + try: + ds = Dataset.from_file(file_path) + # Use streaming-like access to avoid loading everything if possible + # or just take the first N + indices = random.sample(range(len(ds)), min(SAMPLES_PER_FILE, len(ds))) + samples = [ds[i]['text'] for i in indices] + except Exception as e: + print(f"Error reading file: {e}") + continue + + for i, text in enumerate(samples): + result = parse_and_clean(text) + print(f"\n[Sample {i+1}]") + + if result: + print(f"\033[92mKEPT ({result['type']})\033[0m") # Green + print(f"IN: {result['input'][:100].replace(chr(10), ' ')}...") + print(f"OUT: {result['output'][:100].replace(chr(10), ' ')}...") + else: + print(f"\033[91mSKIPPED (Raw/Article)\033[0m") # Red + print(f"Preview: {text[:100].replace(chr(10), ' ')}...") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/run_distillation_math.py b/examples/run_distillation_math.py new file mode 100644 index 0000000000..9674e40cdd --- /dev/null +++ b/examples/run_distillation_math.py @@ -0,0 +1,236 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
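The two inspection scripts above share one core idea: probe each sample for a known delimiter ("Answer:", "Solution:", "output:", ...) and split it into an input/output pair, skipping raw article text that has no delimiter. A condensed, standalone sketch of that marker-splitting heuristic (the function names here are illustrative, not part of the scripts):

```python
import re

def split_on_marker(text: str, marker: str):
    """Case-insensitively split text into (before, after) at the first marker."""
    parts = re.split(re.escape(marker), text, maxsplit=1, flags=re.IGNORECASE)
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    return None

def classify(text: str):
    """Try delimiters in priority order, as parse_and_clean does."""
    for marker, kind in [
        ("Answer:", "Q&A (Answer)"),
        ("Solution:", "Q&A (Solution)"),
        ("output:", "Input/Output"),
    ]:
        pair = split_on_marker(text, marker)
        if pair is not None:
            return {"type": kind, "input": pair[0], "output": pair[1]}
    return None  # no delimiter found: raw article text, skipped

print(classify("Problem: 2 + 2 = ?\nSolution: 4"))
# → {'type': 'Q&A (Solution)', 'input': 'Problem: 2 + 2 = ?', 'output': '4'}
```

The priority order matters: a sample containing both "Answer:" and "Solution:" is split at "Answer:", matching the precedence in `parse_and_clean`.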
+ +import argparse +import os +from collections import defaultdict +from typing import Any, Optional +import glob +import torch +from torch.utils.data import Dataset as TorchDataset +from datasets import load_dataset +from torchdata.stateful_dataloader import StatefulDataLoader + +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.distillation import MasterConfig, distillation_train, offpolicy_distillation_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.data.interfaces import ( + TaskDataProcessFnCallable, + TaskDataSpec, +) +from nemo_rl.data.processors import math_hf_data_processor +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run distillation training with configuration" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# Math Data Processor +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def 
setup_data( + tokenizer: TokenizerType, + data_config: DataConfig, + env_configs: dict[str, Any], + seed: int, +) -> tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\n▶ Setting up data...") + math_task_spec = TaskDataSpec( + task_name="math", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + + # load dataset + data: Any = load_response_dataset(data_config, seed) + task_name = ( + data.task_name if hasattr(data, "task_name") else data.task_spec.task_name + ) + # data processor + task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( + defaultdict(lambda: (math_task_spec, math_hf_data_processor)) + ) + task_data_processors[task_name] = (math_task_spec, math_hf_data_processor) + + # setup math environment + math_env = MathEnvironment.options( # type: ignore # it's wrapped with ray.remote + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + } + ).remote(env_configs["math"]) + + dataset = AllTaskProcessedDataset( + data.formatted_ds["train"], + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset: Optional[AllTaskProcessedDataset] = None + if data.formatted_ds["validation"]: + val_dataset = AllTaskProcessedDataset( + data.formatted_ds["validation"], + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + else: + val_dataset = None + + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) + task_to_env[task_name] = math_env + return dataset, val_dataset, task_to_env, task_to_env + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not 
args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "distillation_math.yaml" + ) + + config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + if config["policy"]["generation"] is not None: + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + else: + print(" ⚠️ No generation config found, this may cause issues") + + # setup data + ( + dataset, + val_dataset, + task_to_env, + val_task_to_env, + ) = setup_data(tokenizer, config["data"], config["env"], 42) + + ( + student_policy, + teacher_policy, + student_generation, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + distillation_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + distillation_train( + student_policy, + teacher_policy, + student_generation, + dataloader, + val_dataloader, + tokenizer, # pass tokenizer parameter + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + distillation_state, + master_config, + ) + + # Initialize Dataset + print("\n▶ Initializing Off-Policy Dataset...") + train_dataset = OffPolicyDistillationDataset( + data_path='/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/*.arrow', + tokenizer=tokenizer, + max_seq_length=master_config["data"]["max_input_seq_length"], + ) + + # Initialize Dataloader + train_dataloader = StatefulDataLoader( + dataset=train_dataset, + batch_size=master_config["distillation"]["batch_size"], + collate_fn=offpolicy_collate_fn, + shuffle=True, + num_workers=4, + pin_memory=True + ) + + # Now call your 
training function + offpolicy_distillation_train( + student_policy=student_policy, + teacher_policy=teacher_policy, + dataloader=train_dataloader, # Pass the new dataloader here + val_dataloader=None, + tokenizer=tokenizer, + loss_fn=loss_fn, + logger=logger, + checkpointer=checkpointer, + distillation_save_state=distillation_state, + master_config=master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_off_policy_distillation.py b/examples/run_off_policy_distillation.py new file mode 100644 index 0000000000..2d9e023a76 --- /dev/null +++ b/examples/run_off_policy_distillation.py @@ -0,0 +1,371 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
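The off-policy loop above is fed by torchdata's `StatefulDataLoader`, whose advantage over a plain `DataLoader` is that `state_dict()` / `load_state_dict()` capture the iteration position, so training can resume mid-epoch from a checkpoint. A minimal pure-Python stand-in (not the real torchdata class) illustrating that contract:

```python
class ResumableLoader:
    """Toy illustration of the StatefulDataLoader resume contract:
    state_dict() records where iteration stopped; load_state_dict()
    restores it so a fresh loader continues from the next batch."""

    def __init__(self, data, batch_size):
        self.data = list(data)
        self.batch_size = batch_size
        self._pos = 0  # index of the next unread sample

    def __iter__(self):
        while self._pos < len(self.data):
            batch = self.data[self._pos : self._pos + self.batch_size]
            self._pos += self.batch_size
            yield batch
        self._pos = 0  # rewind for the next epoch

    def state_dict(self):
        return {"pos": self._pos}

    def load_state_dict(self, state):
        self._pos = state["pos"]

loader = ResumableLoader(range(10), batch_size=2)
it = iter(loader)
first = next(it)            # [0, 1]
ckpt = loader.state_dict()  # {'pos': 2} -- saved alongside model weights

# Restore into a fresh loader (e.g. after a job restart) and continue.
resumed = ResumableLoader(range(10), batch_size=2)
resumed.load_state_dict(ckpt)
print(next(iter(resumed)))  # → [2, 3]
```

In the real scripts the checkpointer saves the dataloader state next to the policy weights, which is why `StatefulDataLoader` is used instead of `torch.utils.data.DataLoader`.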
+ +""" +Off-Policy Distillation Training Script + +This script runs off-policy distillation where: +- A fixed dataset of prompt-response pairs is used (no student generation) +- Teacher provides logits for the fixed responses +- Student aligns with teacher using KL divergence loss + +Usage: + python run_off_policy_distillation.py --config configs/off_policy_distillation.yaml + +For your arrow dataset: + python run_off_policy_distillation.py --config configs/off_policy_distillation.yaml \ + data.arrow_files="/path/to/your/*.arrow" + +Reference: https://github.com/NVIDIA-NeMo/RL/discussions/1445 +""" + +import argparse +import glob +import os +from functools import partial +from typing import Any, Optional + +import torch +from datasets import Dataset, load_dataset +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyMasterConfig, + off_policy_distillation_train, + setup, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run off-policy distillation training" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +# =============================================================================== +# Data Processing for Off-Policy Distillation +# 
=============================================================================== + + +def off_policy_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: PreTrainedTokenizerBase, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """ + Process a datum dictionary for off-policy distillation. + + This processor handles datasets with prompt-response pairs where the response + is already provided. It creates message_log with token_ids and loss masks. + + Supports multiple input formats: + 1. {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} + 2. {"conversations": [{"from": "human/gpt", "value": "..."}]} # ShareGPT + 3. {"prompt": "...", "response": "..."} + 4. {"input": "...", "output": "..."} + 5. {"instruction": "...", "output": "..."} # Alpaca + 6. {"text": "..."} # Full text - train on all tokens (language modeling style) + """ + + # Special handling for raw text format (no chat structure) + if "text" in datum_dict and len(datum_dict.keys()) == 1: + # Raw text format - tokenize directly without chat template + # Train on all tokens (language modeling / SFT style) + text = datum_dict["text"] + + # Add BOS token if tokenizer has one + if tokenizer.bos_token: + text = tokenizer.bos_token + text + + token_ids = tokenizer( + text, + return_tensors="pt", + add_special_tokens=False, + truncation=True, + max_length=max_seq_length, + )["input_ids"][0] + + # Train on all tokens + token_loss_mask = torch.ones_like(token_ids) + + length = len(token_ids) + loss_multiplier = 1.0 if length <= max_seq_length else 0.0 + + message_log: LLMMessageLogType = [{ + "role": "assistant", # Mark as assistant so loss is computed + "content": text[:500] + "..." 
if len(text) > 500 else text, + "token_ids": token_ids, + "token_loss_mask": token_loss_mask, + }] + + return { + "message_log": message_log, + "length": length, + "extra_env_info": {}, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": "off_policy_distillation", + } + + # Handle chat-structured formats + messages = None + + if "messages" in datum_dict: + messages = datum_dict["messages"] + elif "conversations" in datum_dict: + # ShareGPT format + messages = [] + for conv in datum_dict["conversations"]: + role_from = conv.get("from", conv.get("role", "")) + if role_from in ["gpt", "assistant", "model", "chatbot"]: + role = "assistant" + elif role_from in ["system"]: + role = "system" + else: + role = "user" + content = conv.get("value", conv.get("content", "")) + messages.append({"role": role, "content": content}) + elif "prompt" in datum_dict and "response" in datum_dict: + messages = [ + {"role": "user", "content": datum_dict["prompt"]}, + {"role": "assistant", "content": datum_dict["response"]}, + ] + elif "input" in datum_dict and "output" in datum_dict: + messages = [ + {"role": "user", "content": datum_dict["input"]}, + {"role": "assistant", "content": datum_dict["output"]}, + ] + elif "instruction" in datum_dict: + user_content = datum_dict["instruction"] + if "input" in datum_dict and datum_dict["input"]: + user_content = f"{user_content}\n\n{datum_dict['input']}" + response = datum_dict.get("output", datum_dict.get("response", "")) + messages = [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": response}, + ] + elif "text" in datum_dict: + # Text with other keys - treat as assistant response + messages = [{"role": "assistant", "content": datum_dict["text"]}] + else: + raise ValueError( + f"Unsupported datum format. Expected: messages, conversations, " + f"prompt/response, input/output, instruction/output, or text. 
" + f"Got keys: {list(datum_dict.keys())}" + ) + + # Add system prompt if specified + if task_data_spec.system_prompt: + messages = [{"role": "system", "content": task_data_spec.system_prompt}] + messages + + # Build message_log with tokenization + message_log: LLMMessageLogType = [] + + for i, msg in enumerate(messages): + role = msg["role"] + content = msg["content"] + + # Apply prompt template for user messages + if role == "user" and task_data_spec.prompt: + content = task_data_spec.prompt.format(content) + + # Add generation prompt only for last user message before assistant + add_gen_prompt = ( + role == "user" + and i + 1 < len(messages) + and messages[i + 1]["role"] == "assistant" + ) + + # Tokenize + chat_msg = [{"role": role, "content": content}] + formatted = tokenizer.apply_chat_template( + chat_msg, + tokenize=False, + add_generation_prompt=add_gen_prompt, + add_special_tokens=(i == 0), + ) + + token_ids = tokenizer( + formatted, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + # Loss mask: 1 for assistant, 0 for others + if role == "assistant": + token_loss_mask = torch.ones_like(token_ids) + else: + token_loss_mask = torch.zeros_like(token_ids) + + message_log.append({ + "role": role, + "content": formatted, + "token_ids": token_ids, + "token_loss_mask": token_loss_mask, + }) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + max_per_msg = max(4, max_seq_length // len(message_log)) + message["token_ids"] = message["token_ids"][:max_per_msg] + message["token_loss_mask"] = message["token_loss_mask"][:max_per_msg] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": datum_dict.get("extra_env_info", {}), + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict.get("task_name", "off_policy_distillation"), + } + + +def load_arrow_dataset( + data_files: list[str], + 
val_split: float = 0.0, + seed: int = 42, +) -> tuple[Dataset, Optional[Dataset]]: + """Load dataset from arrow files.""" + print(f"Loading {len(data_files)} arrow files...") + dataset = load_dataset("arrow", data_files=data_files, split="train") + print(f" ✓ Loaded {len(dataset)} samples") + + if val_split > 0: + split = dataset.train_test_split(test_size=val_split, seed=seed) + return split["train"], split["test"] + return dataset, None + + +def setup_data( + tokenizer: PreTrainedTokenizerBase, + data_config: DataConfig, + seed: int = 42, +) -> AllTaskProcessedDataset: + """Setup data for off-policy distillation.""" + print("\n▶ Setting up data for off-policy distillation...") + + task_spec = TaskDataSpec( + task_name="off_policy_distillation", + prompt_file=data_config.get("prompt_file"), + system_prompt_file=data_config.get("system_prompt_file"), + ) + + # Load dataset based on format + if "arrow_files" in data_config: + arrow_files = data_config["arrow_files"] + if isinstance(arrow_files, str): + arrow_files = glob.glob(arrow_files) + train_ds, _ = load_arrow_dataset(arrow_files, seed=seed) + elif "train_data_path" in data_config: + train_ds = load_dataset( + "json", data_files=data_config["train_data_path"], split="train" + ) + elif "hf_dataset" in data_config: + hf_config = data_config["hf_dataset"] + train_ds = load_dataset( + hf_config["name"], + hf_config.get("subset"), + split=hf_config.get("split", "train"), + ) + else: + raise ValueError( + "Data config must have: 'arrow_files', 'train_data_path', or 'hf_dataset'" + ) + + # Create processed dataset + train_dataset = AllTaskProcessedDataset( + train_ds, + tokenizer, + task_spec, + off_policy_data_processor, + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset + + +def main() -> None: + """Main entry point.""" + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "off_policy_distillation.yaml" + ) + + 
config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + + config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True) + + # Get experiment directory + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Setup data + train_dataset = setup_data( + tokenizer, config["data"], config["distillation"]["seed"] + ) + + # Setup and run training + ( + student_policy, + teacher_policy, + dataloader, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + ) = setup(config, tokenizer, train_dataset) + + off_policy_distillation_train( + student_policy, + teacher_policy, + dataloader, + tokenizer, + loss_fn, + logger, + checkpointer, + save_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_off_policy_distillation_arrow.py b/examples/run_off_policy_distillation_arrow.py new file mode 100644 index 0000000000..37ec45e837 --- /dev/null +++ b/examples/run_off_policy_distillation_arrow.py @@ -0,0 +1,257 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Off-Policy Distillation with Arrow Dataset Support + +This script runs off-policy distillation using the same data loading +pattern as run_sft.py (load_response_dataset + AllTaskProcessedDataset). 
+ +Arrow files are loaded via the existing ArrowTextDataset class by setting +dataset_name: "arrow_text" in the config. ArrowTextDataset handles: + - Loading .arrow files via glob patterns + - Wrapping text as {"messages": [{"role": "assistant", "content": }]} + - Splitting into train/validation + +Off-policy: no student generation, teacher provides logits for fixed responses. +""" + +import argparse +import os +import pprint +from functools import partial +from typing import Any, Optional, Callable + +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyMasterConfig, + off_policy_distillation_train, + setup, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec +from nemo_rl.data.llm_message_utils import get_formatted_message_log +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) +OmegaConf.register_new_resolver("max", lambda a, b: max(a, b)) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run off-policy distillation with Arrow dataset support" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + args, overrides = parser.parse_known_args() + return args, overrides + + +# ======================================================= +# Data Processing (following run_sft.py pattern) +# ======================================================= +def sft_preprocessor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + 
idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False, + datum_preprocessor: Optional[Callable] = None, +) -> DatumSpec: + """Process a datum dictionary for off-policy distillation. + + Same as run_sft.py's sft_preprocessor. ArrowTextDataset already wraps + plain text into messages format, so we only need to handle messages here. + """ + # optional preprocessor + if datum_preprocessor is not None: + datum_dict = datum_preprocessor(datum_dict) + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + tools=datum_dict.get("tools", None), # Pass tools from data if present + ) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output = { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + + +def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int): + """Setup data for off-policy distillation. + + Uses load_response_dataset exactly like run_sft.py. For arrow files, + set dataset_name: "arrow_text" in the config and ArrowTextDataset + handles the rest. + """ + print("\n▶ Setting up data...") + + # load dataset + data = load_response_dataset(data_config, seed) + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + task_spec = data.task_spec + print( + f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset)} samples, respectively." 
+ ) + + train_dataset = AllTaskProcessedDataset( + train_dataset, + tokenizer, + task_spec, + partial( + sft_preprocessor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + add_generation_prompt=data_config.get("add_generation_prompt", False), + ), + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + task_spec, + partial( + sft_preprocessor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + add_generation_prompt=data_config.get("add_generation_prompt", False), + ), + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset, val_dataset, task_spec + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "off_policy_distillation_math.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Setup data + dataset, val_dataset, task_spec = setup_data( + tokenizer, config["data"], config["distillation"]["seed"] + ) + + # ---------- quick dataloader sanity check ---------- + # from torchdata.stateful_dataloader import StatefulDataLoader + # from nemo_rl.data.collate_fn import 
rl_collate_fn + # batch_size = config["distillation"]["num_prompts_per_step"] + # _dl = StatefulDataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=rl_collate_fn, drop_last=True) + # print(f"\nDataloader: {len(dataset)} samples, {len(_dl)} batches (bs={batch_size})") + # for _i, _b in enumerate(_dl): + # if _i >= 3: + # break + # print(f" batch {_i}: lengths={_b['length'].tolist()}, loss_mult={_b['loss_multiplier'].tolist()}") + # _m = _b['message_log'][0][0] + # print(f" sample[0]: role={_m['role']}, tok[:10]={_m['token_ids'][:10].tolist()}") + # print(f" text[:100]: {tokenizer.decode(_m['token_ids'][:30], skip_special_tokens=False)[:100]}") + # print("\nDataloader OK") + # import sys; sys.exit(0) + # --------------------------------------------------- + + # Setup off-policy distillation (no student generation needed) + ( + student_policy, + teacher_policy, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + distillation_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + # Run off-policy distillation training + off_policy_distillation_train( + student_policy, + teacher_policy, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + logger, + checkpointer, + distillation_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_off_policy_distillation_arrow_with_eval.py b/examples/run_off_policy_distillation_arrow_with_eval.py new file mode 100644 index 0000000000..76508e83f9 --- /dev/null +++ b/examples/run_off_policy_distillation_arrow_with_eval.py @@ -0,0 +1,672 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Off-policy distillation on arrow data with inline MATH/MMLU evaluation. + +Extends run_off_policy_distillation_arrow.py with periodic generation-based +evaluation (MATH, MMLU) using a colocated vLLM generation engine, following +the same pattern as run_sft_arrow_with_eval.py. + +Usage: + uv run examples/run_off_policy_distillation_arrow_with_eval.py \ + --config examples/configs/llama_off_policy_arrow.yaml +""" + +import argparse +import os +import sys +import pprint + +# Force unbuffered stdout/stderr so logs appear immediately in SLURM output files +os.environ["PYTHONUNBUFFERED"] = "1" +if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(line_buffering=True) +if hasattr(sys.stderr, "reconfigure"): + sys.stderr.reconfigure(line_buffering=True) +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Callable, Optional, cast + +import torch +from omegaconf import OmegaConf +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from nemo_rl.algorithms.loss_functions import ( + CrossTokenizerDistillationLossFn, + DistillationLossFn, +) +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyDistillationSaveState, + OffPolicyMasterConfig, + _default_distillation_save_state, + check_vocab_equality, + off_policy_distillation_train, +) +from nemo_rl.algorithms.utils import set_seed +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets 
import AllTaskProcessedDataset, load_eval_dataset +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec +from nemo_rl.data.llm_message_utils import get_keys_from_message_log +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster, init_ray +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointManager +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import Logger, get_next_experiment_dir, print_message_log_samples +from nemo_rl.utils.timer import Timer + +import ray + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True) +OmegaConf.register_new_resolver("max", lambda a, b: max(a, b), replace=True) + + +# ========================================================================= +# CLI +# ========================================================================= +def parse_args(): + parser = argparse.ArgumentParser( + description="Off-policy distillation on arrow data with inline MATH/MMLU eval" + ) + parser.add_argument("--config", type=str, default=None, help="Path to YAML config") + args, overrides = parser.parse_known_args() + return args, overrides + + +# ========================================================================= +# Arrow-text training data (reused from run_off_policy_distillation_arrow.py) +# ========================================================================= +def _sft_preprocessor( + datum_dict: dict[str, Any], + task_data_spec: 
TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False, + datum_preprocessor: Optional[Callable] = None, +) -> DatumSpec: + from nemo_rl.data.llm_message_utils import get_formatted_message_log + + if datum_preprocessor is not None: + datum_dict = datum_preprocessor(datum_dict) + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + tools=datum_dict.get("tools", None), + ) + + length = sum(len(m["token_ids"]) for m in message_log) + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + +def _kd_preprocessor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Knowledge-distillation preprocessor: raw text, no chat template, loss on all tokens. + + Both student and teacher tokenize the same raw text, matching the + train_distillation_ddp.py pipeline. The raw text is stored in + extra_env_info so the teacher can tokenize it directly. 
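    A toy sketch of the consequence (the split/char functions below are
    stand-ins for real tokenizers, not the actual Qwen/Llama vocabularies):

    ```python
    # Stand-in "tokenizers": a word splitter vs. a character splitter. Real
    # student/teacher tokenizers differ the same way, just less drastically.
    def word_tokenize(text: str) -> list[str]:
        return text.split()

    def char_tokenize(text: str) -> list[str]:
        return list(text)

    raw_text = "Distill the teacher."
    student_tokens = word_tokenize(raw_text)   # 3 tokens
    teacher_tokens = char_tokenize(raw_text)   # 20 tokens

    # Same text, different segmentations -- so token ids cannot be shared,
    # and the teacher must re-tokenize the raw string itself.
    assert len(student_tokens) != len(teacher_tokens)
    ```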
+ """ + raw_text = "\n".join( + msg["content"] + for msg in datum_dict["messages"] + if isinstance(msg.get("content"), str) + ) + + token_ids = tokenizer( + raw_text, + return_tensors="pt", + add_special_tokens=True, + max_length=max_seq_length, + truncation=True, + )["input_ids"][0] + + length = len(token_ids) + loss_multiplier = 1.0 + if length > max_seq_length: + loss_multiplier = 0.0 + + message_log = [ + { + "role": "assistant", + "content": raw_text, + "token_ids": token_ids, + "token_loss_mask": torch.ones_like(token_ids), + } + ] + + return { + "message_log": message_log, + "length": length, + "extra_env_info": {"raw_text": raw_text}, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + +def setup_train_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int): + from nemo_rl.data.datasets import load_response_dataset + + print("\n▶ Setting up training data...") + data = load_response_dataset(data_config, seed) + train_dataset_raw = data.formatted_ds["train"] + task_spec = data.task_spec + + train_dataset = AllTaskProcessedDataset( + train_dataset_raw, + tokenizer, + task_spec, + _kd_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + print(f" ✓ Training dataset loaded with {len(train_dataset)} samples") + return train_dataset, task_spec + + +# ========================================================================= +# Eval data + environments (from run_sft_arrow_with_eval.py) +# ========================================================================= +def setup_eval_data( + tokenizer: AutoTokenizer, + eval_config: dict[str, Any], + max_seq_length: int, +) -> tuple[ + dict[str, StatefulDataLoader], + dict[str, dict[str, EnvironmentInterface]], +]: + print("\n▶ Setting up evaluation benchmarks...") + eval_dataloaders: dict[str, StatefulDataLoader] = {} + eval_envs: dict[str, dict[str, EnvironmentInterface]] = {} + + for bench_name, bench_cfg in eval_config["benchmarks"].items(): + dataset_name = 
bench_cfg["dataset_name"] + prompt_file = bench_cfg.get("prompt_file") + system_prompt_file = bench_cfg.get("system_prompt_file") + env_cfg = bench_cfg.get("env", {"num_workers": 8}) + + data_cfg = { + "dataset_name": dataset_name, + "prompt_file": prompt_file, + "system_prompt_file": system_prompt_file, + "num_few_shot": bench_cfg.get("num_few_shot", 0), + } + base_dataset = load_eval_dataset(data_cfg) + + task_spec = TaskDataSpec( + task_name=dataset_name, + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + + dataset = AllTaskProcessedDataset( + dataset=base_dataset.rekeyed_ds, + tokenizer=tokenizer, + default_task_data_spec=task_spec, + task_data_processors=base_dataset.processor, + max_seq_length=max_seq_length, + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=eval_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + + math_env = MathEnvironment.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), + } + ).remote(env_cfg) + + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) + task_to_env[dataset_name] = math_env + task_to_env[None] = math_env + + eval_dataloaders[bench_name] = dataloader + eval_envs[bench_name] = task_to_env + print(f" ✓ {bench_name}: {len(dataset)} samples, env={dataset_name}") + + return eval_dataloaders, eval_envs + + +# ========================================================================= +# Generation-based validation (from run_sft_arrow_with_eval.py) +# ========================================================================= +def gen_validate( + generation: GenerationInterface, + eval_dataloaders: dict[str, StatefulDataLoader], + eval_envs: dict[str, dict[str, EnvironmentInterface]], + eval_config: dict[str, Any], + master_config: dict[str, Any], + step: int, + tokenizer: PreTrainedTokenizerBase | None = None, +) -> tuple[dict[str, Any], 
dict[str, Any]]: + timer = Timer() + all_val_metrics: dict[str, Any] = {} + + max_val_samples = eval_config.get("max_val_samples", 512) + val_batch_size = eval_config["val_batch_size"] + max_batches = max_val_samples // val_batch_size + max_rollout_turns = eval_config.get("max_rollout_turns", 1) + max_seq_len = master_config["policy"]["max_total_sequence_length"] + + with timer.time("total_eval_time"): + for bench_name, dataloader in eval_dataloaders.items(): + print(f"\n▶ Evaluating {bench_name} at step {step}...", flush=True) + total_rewards = [] + total_lengths = [] + all_message_logs = [] + + for batch_idx, val_batch in enumerate(dataloader): + if batch_idx >= max_batches: + break + + val_batch, gen_metrics = run_multi_turn_rollout( + generation, + val_batch, + tokenizer, + eval_envs[bench_name], + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + greedy=True, + ) + + rewards = val_batch["total_reward"] + total_rewards.extend(rewards.tolist()) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + + to_env = [ + get_keys_from_message_log( + val_batch["message_log"][i], ["role", "content"] + ) + for i in range(len(val_batch["message_log"])) + ] + all_message_logs.extend(to_env) + + accuracy = ( + sum(total_rewards) / len(total_rewards) + if len(total_rewards) > 0 + else 0 + ) + avg_length = ( + sum(total_lengths) / len(total_lengths) + if len(total_lengths) > 0 + else 0 + ) + + all_val_metrics[f"{bench_name}_accuracy"] = accuracy + all_val_metrics[f"{bench_name}_avg_length"] = avg_length + + print(f"\n📊 {bench_name} Results:") + print(f" • Accuracy: {accuracy:.4f}") + print(f" • Avg response length: {avg_length:.1f} tokens") + print(f" • Samples processed: {len(total_rewards)}", flush=True) + + try: + num_to_print = master_config["logger"].get( + "num_val_samples_to_print", 3 + ) + print_message_log_samples( + all_message_logs, + total_rewards, + num_samples=min(num_to_print, len(all_message_logs)), + step=step, + ) + except 
Exception as e: + print(f" ⚠️ Error displaying samples: {e}", flush=True) + + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + eval_time = timing_metrics.get("total_eval_time", 0) + print(f"\n ⏱️ Total eval time: {eval_time:.2f}s", flush=True) + timer.reset() + + return all_val_metrics, timing_metrics + + +# ========================================================================= +# Eval hook factory for generation-based evaluation (MATH/MMLU) +# ========================================================================= +def make_gen_eval_hook(generation, eval_dataloaders, eval_envs, + eval_config, master_config, tokenizer, colocated_inference): + """Create a closure that wraps gen_validate for use as an eval_hook callback. + + The returned function manages vLLM weight refitting and generation lifecycle + so the shared training loop in off_policy_distillation.py doesn't need to + know about generation-based evaluation details. + """ + generation_stale = True + + def hook(step, student_policy, teacher_policy, logger): + nonlocal generation_stale + from nemo_rl.algorithms.grpo import refit_policy_generation + + if generation_stale: + refit_policy_generation(student_policy, generation, colocated_inference) + generation_stale = False + + val_metrics, val_timings = gen_validate( + generation, eval_dataloaders, eval_envs, + eval_config, master_config, step=step, tokenizer=tokenizer, + ) + generation.finish_generation() + logger.log_metrics(val_timings, step, prefix="timing/validation") + logger.log_metrics(val_metrics, step, prefix="validation") + generation_stale = True + return val_metrics + + return hook + + +# ========================================================================= +# Main +# ========================================================================= +def main(): + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "llama_off_policy_arrow.yaml" + ) + + config = 
load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True) + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"📊 Using log directory: {config['logger']['log_dir']}") + + init_ray() + + # ── Tokenizer ── + from nemo_rl.algorithms.utils import get_tokenizer + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # ── Configure generation (for eval only) ── + generation_config = config["policy"].get("generation") + if generation_config is not None: + config["policy"]["generation"] = configure_generation_config( + generation_config, tokenizer + ) + + # ── Training data (arrow) ── + train_dataset, task_spec = setup_train_data( + tokenizer, config["data"], config["distillation"]["seed"] + ) + + # ── Core setup ── + set_seed(config["distillation"]["seed"]) + + policy_config = config["policy"] + teacher_config = config["teacher"] + distillation_config = config["distillation"] + data_config = config["data"] + cluster_config = config["cluster"] + + logger = Logger(config["logger"]) + logger.log_hyperparams(config) + + checkpointer = CheckpointManager(config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + distillation_save_state: Optional[OffPolicyDistillationSaveState] = cast( + Optional[OffPolicyDistillationSaveState], + checkpointer.load_training_info(last_checkpoint_path), + ) + if distillation_save_state is None: + distillation_save_state = _default_distillation_save_state() + + # ── Dataloader ── + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=distillation_config["num_prompts_per_step"], + shuffle=data_config.get("shuffle", True), + collate_fn=rl_collate_fn, + drop_last=True, + ) + if last_checkpoint_path: 
+ train_dataloader.load_state_dict( + torch.load(os.path.join(last_checkpoint_path, "train_dataloader.pt")) + ) + + has_generation = generation_config is not None + max_colocated = 4 if has_generation else 3 + + # ── Cluster ── + print("\n▶ Setting up compute cluster...") + cluster = RayVirtualCluster( + name="off_policy_distillation_eval_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=max_colocated, + ) + print( + f" ✓ Cluster: {cluster_config['num_nodes']} nodes, max_colocated={max_colocated}" + ) + + # ── Cross-tokenizer setup ── + token_aligner_cfg = config.get("token_aligner", {}) + cross_tokenizer_enabled = token_aligner_cfg.get("enabled", False) + token_aligner = None + teacher_tokenizer = None + + if cross_tokenizer_enabled: + from nemo_rl.algorithms.x_token import TokenAligner + + print("\n▶ Setting up cross-tokenizer distillation (TokenAligner)...") + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_config["model_name"]) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + token_aligner = TokenAligner( + teacher_tokenizer_name=teacher_config["model_name"], + student_tokenizer_name=policy_config["model_name"], + max_comb_len=token_aligner_cfg.get("max_comb_len", 4), + projection_matrix_multiplier=token_aligner_cfg.get( + "projection_matrix_multiplier", 1.0 + ), + ) + token_aligner._load_logits_projection_map( + file_path=token_aligner_cfg["projection_matrix_path"], + use_sparse_format=token_aligner_cfg.get("use_sparse_format", True), + learnable=token_aligner_cfg.get("learnable", False), + device="cpu", + ) + if token_aligner_cfg.get("project_teacher_to_student", False): + token_aligner.create_reverse_projection_matrix(device="cpu") + + token_aligner.precompute_canonical_maps() + + print(f" ✓ TokenAligner initialized 
({policy_config['model_name']} → {teacher_config['model_name']})") + else: + # ── Vocab check (same-tokenizer mode only) ── + if not bool(os.getenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", False)): + check_vocab_equality( + tokenizer, policy_config["model_name"], teacher_config["model_name"] + ) + + # ── Teacher Policy ── + print("\n▶ Setting up teacher policy...") + if teacher_config.get("megatron_cfg", {}).get("enabled", False): + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(train_dataloader), + ) + teacher_config["megatron_cfg"]["train_iters"] = total_train_iters + + teacher_policy = Policy( + name_prefix="teacher", + cluster=cluster, + config=teacher_config, + tokenizer=teacher_tokenizer if cross_tokenizer_enabled else tokenizer, + weights_path=None, + optimizer_path=None, + init_optimizer=False, + init_reference_model=False, + ) + teacher_policy.offload_after_refit() + + # ── Student Policy ── + print("\n▶ Setting up student policy...") + weights_path = None + optimizer_path = None + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + + if policy_config.get("megatron_cfg", {}).get("enabled", False): + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(train_dataloader), + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + student_policy = Policy( + name_prefix="student", + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + init_reference_model=False, + ) + + if cross_tokenizer_enabled: + loss_fn = CrossTokenizerDistillationLossFn(config["loss_fn"], token_aligner) + else: + loss_fn = DistillationLossFn(config["loss_fn"]) + + # ── vLLM Generation (colocated, for eval only) ── + generation: 
Optional[GenerationInterface] = None + if has_generation: + print("\n▶ Setting up vLLM generation (colocated, for eval)...") + gen_cfg = config["policy"]["generation"] + gen_cfg["model_name"] = policy_config["model_name"] + if "vllm_cfg" in gen_cfg: + gen_cfg["vllm_cfg"]["hf_overrides"] = policy_config.get( + "hf_config_overrides", {} + ) + + generation = VllmGeneration( + cluster=cluster, config=cast(VllmConfig, gen_cfg) + ) + generation.finish_generation() + + state_dict_info = student_policy.prepare_refit_info() + generation.prepare_refit_info(state_dict_info) + print(f" ✓ vLLM generation ready (model={policy_config['model_name']})") + + # ── Eval datasets + environments ── + eval_dataloaders: Optional[dict[str, StatefulDataLoader]] = None + eval_envs: Optional[dict[str, dict[str, EnvironmentInterface]]] = None + + eval_config = config.get("eval") + if eval_config and has_generation: + eval_dataloaders, eval_envs = setup_eval_data( + tokenizer, + eval_config, + max_seq_length=policy_config["max_total_sequence_length"], + ) + + print("\n" + "=" * 60) + print(" " * 10 + "OFF-POLICY DISTILLATION + EVAL SETUP COMPLETE") + print("=" * 60 + "\n") + + # ── Build eval hook ── + eval_hook = None + eval_hook_period = 0 + eval_hook_at_start = False + if has_generation and eval_config and eval_dataloaders and eval_envs: + colocated_inference = ( + config["policy"]["generation"]["colocated"]["enabled"] + if config["policy"].get("generation") + else True + ) + eval_hook = make_gen_eval_hook( + generation, eval_dataloaders, eval_envs, + eval_config, config, tokenizer, colocated_inference, + ) + eval_hook_period = eval_config["val_period"] + eval_hook_at_start = eval_config.get("val_at_start", False) + + # ── Train ── + off_policy_distillation_train( + student_policy=student_policy, + teacher_policy=teacher_policy, + dataloader=train_dataloader, + val_dataloader=None, + tokenizer=tokenizer, + loss_fn=loss_fn, + logger=logger, + checkpointer=checkpointer, + 
distillation_save_state=distillation_save_state, + master_config=config, + eval_hook=eval_hook, + eval_hook_period=eval_hook_period, + eval_hook_at_start=eval_hook_at_start, + token_aligner=token_aligner, + teacher_tokenizer=teacher_tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_off_policy_distillation_math.py b/examples/run_off_policy_distillation_math.py new file mode 100644 index 0000000000..459d200fce --- /dev/null +++ b/examples/run_off_policy_distillation_math.py @@ -0,0 +1,344 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
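
Before the script itself, a minimal sketch of the loss it optimizes: a masked forward KL between the teacher's and student's per-token distributions. The function and tensor names here are illustrative, not the internals of NeMo RL's `DistillationLossFn`:

```python
import torch
import torch.nn.functional as F

def masked_forward_kl(
    student_logits: torch.Tensor,   # (batch, seq, vocab)
    teacher_logprobs: torch.Tensor, # (batch, seq, vocab), already log-softmaxed
    loss_mask: torch.Tensor,        # (batch, seq), 1 where loss is computed
) -> torch.Tensor:
    """KL(teacher || student), averaged over unmasked token positions."""
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    # Per-token KL: sum_v p_t(v) * (log p_t(v) - log p_s(v))
    kl_per_token = (
        teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)
    ).sum(dim=-1)
    return (kl_per_token * loss_mask).sum() / loss_mask.sum().clamp(min=1)
```

When the student's distribution matches the teacher's exactly the loss is zero; otherwise it is nonnegative, pulling the student toward the teacher at every unmasked position.
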
+ +""" +Off-Policy Distillation Training Script (Math) + +This script runs off-policy distillation where: +- A fixed dataset of prompt-response pairs is used (no student generation) +- Teacher provides logits for the fixed responses +- Student aligns with teacher using KL divergence loss + +Key difference from on-policy distillation: +- No student generation step - uses pre-existing responses from dataset +- No environment needed for reward computation +""" + +import argparse +import os +from functools import partial +from typing import Any + +import torch +from omegaconf import OmegaConf +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyMasterConfig, + off_policy_distillation_train, + setup, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Run off-policy distillation training with configuration" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# Off-Policy Data Processor +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def off_policy_data_processor( + datum_dict: 
dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: PreTrainedTokenizerBase, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, +) -> DatumSpec: + """ + Process a datum dictionary for off-policy distillation. + + This processor handles datasets with prompt-response pairs where the response + is already provided. It creates message_log with token_ids and loss masks. + + Supports multiple input formats: + 1. {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]} + 2. {"conversations": [{"from": "human/gpt", "value": "..."}]} # ShareGPT + 3. {"prompt": "...", "response": "..."} + 4. {"input": "...", "output": "..."} + 5. {"instruction": "...", "output": "..."} # Alpaca + 6. {"text": "..."} # Full text - train on all tokens (language modeling style) + """ + + # Special handling for raw text format (no chat structure) + if "text" in datum_dict and len(datum_dict.keys()) == 1: + # Raw text format - tokenize directly without chat template + # Train on all tokens (language modeling / SFT style) + text = datum_dict["text"] + + # Add BOS token if tokenizer has one and add_bos is True + if add_bos and tokenizer.bos_token: + text = tokenizer.bos_token + text + + token_ids = tokenizer( + text, + return_tensors="pt", + add_special_tokens=False, + truncation=True, + max_length=max_seq_length, + )["input_ids"][0] + + # Train on all tokens + token_loss_mask = torch.ones_like(token_ids) + + length = len(token_ids) + loss_multiplier = 1.0 if length <= max_seq_length else 0.0 + + message_log: LLMMessageLogType = [{ + "role": "assistant", # Mark as assistant so loss is computed + "content": text[:500] + "..." 
if len(text) > 500 else text, + "token_ids": token_ids, + "token_loss_mask": token_loss_mask, + }] + + return { + "message_log": message_log, + "length": length, + "extra_env_info": {}, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": "off_policy_distillation", + } + + # Handle chat-structured formats + messages = None + + if "messages" in datum_dict: + messages = datum_dict["messages"] + elif "conversations" in datum_dict: + # ShareGPT format + messages = [] + for conv in datum_dict["conversations"]: + role_from = conv.get("from", conv.get("role", "")) + if role_from in ["gpt", "assistant", "model", "chatbot"]: + role = "assistant" + elif role_from in ["system"]: + role = "system" + else: + role = "user" + content = conv.get("value", conv.get("content", "")) + messages.append({"role": role, "content": content}) + elif "prompt" in datum_dict and "response" in datum_dict: + messages = [ + {"role": "user", "content": datum_dict["prompt"]}, + {"role": "assistant", "content": datum_dict["response"]}, + ] + elif "input" in datum_dict and "output" in datum_dict: + messages = [ + {"role": "user", "content": datum_dict["input"]}, + {"role": "assistant", "content": datum_dict["output"]}, + ] + elif "instruction" in datum_dict: + user_content = datum_dict["instruction"] + if "input" in datum_dict and datum_dict["input"]: + user_content = f"{user_content}\n\n{datum_dict['input']}" + response = datum_dict.get("output", datum_dict.get("response", "")) + messages = [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": response}, + ] + elif "text" in datum_dict: + # Text with other keys - treat as assistant response + messages = [{"role": "assistant", "content": datum_dict["text"]}] + else: + raise ValueError( + f"Unsupported datum format. Expected: messages, conversations, " + f"prompt/response, input/output, instruction/output, or text. 
" + f"Got keys: {list(datum_dict.keys())}" + ) + + # Add system prompt if specified + if task_data_spec.system_prompt: + messages = [{"role": "system", "content": task_data_spec.system_prompt}] + messages + + # Build message_log with tokenization + message_log: LLMMessageLogType = [] + + for i, msg in enumerate(messages): + role = msg["role"] + content = msg["content"] + + # Apply prompt template for user messages + if role == "user" and task_data_spec.prompt: + content = task_data_spec.prompt.format(content) + + # Add generation prompt only for last user message before assistant + add_gen_prompt = ( + role == "user" + and i + 1 < len(messages) + and messages[i + 1]["role"] == "assistant" + ) + + # Tokenize + chat_msg = [{"role": role, "content": content}] + formatted = tokenizer.apply_chat_template( + chat_msg, + tokenize=False, + add_generation_prompt=add_gen_prompt, + add_special_tokens=(i == 0), + ) + + token_ids = tokenizer( + formatted, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + # Loss mask: 1 for assistant, 0 for others + if role == "assistant": + token_loss_mask = torch.ones_like(token_ids) + else: + token_loss_mask = torch.zeros_like(token_ids) + + message_log.append({ + "role": role, + "content": formatted, + "token_ids": token_ids, + "token_loss_mask": token_loss_mask, + }) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + max_per_msg = max(4, max_seq_length // len(message_log)) + message["token_ids"] = message["token_ids"][:max_per_msg] + message["token_loss_mask"] = message["token_loss_mask"][:max_per_msg] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": datum_dict.get("extra_env_info", {}), + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict.get("task_name", "off_policy_distillation"), + } + + +def setup_data(tokenizer: AutoTokenizer, data_config: 
DataConfig, seed: int): + """ + Setup data for off-policy distillation using load_response_dataset (like run_sft.py). + + This uses the same data loading infrastructure as SFT training. + """ + print("\n▶ Setting up data for off-policy distillation...") + + # Load dataset using the same approach as run_sft.py + data = load_response_dataset(data_config, seed) + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + task_spec = data.task_spec + print( + f" ✓ Training and validation datasets loaded with {len(train_dataset)} and {len(val_dataset)} samples, respectively." + ) + + # Use the off-policy data processor (includes token_loss_mask for distillation) + train_dataset = AllTaskProcessedDataset( + train_dataset, + tokenizer, + task_spec, + partial( + off_policy_data_processor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + ), + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset, val_dataset, task_spec + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "off_policy_distillation_math.yaml" + ) + + config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + + config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # Note: No generation config needed for off-policy distillation + # since we don't generate responses from the student + + # Setup data using load_response_dataset (like run_sft.py) + dataset, val_dataset, task_spec = setup_data( + tokenizer, config["data"], config["distillation"]["seed"] + ) + + # Setup returns fewer 
items than on-policy (no student_generation, no val_dataloader) + ( + student_policy, + teacher_policy, + dataloader, + loss_fn, + logger, + checkpointer, + distillation_state, + master_config, + ) = setup(config, tokenizer, dataset) + + # Off-policy training: no student_generation, no environments + off_policy_distillation_train( + student_policy, + teacher_policy, + dataloader, + tokenizer, + loss_fn, + logger, + checkpointer, + distillation_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_off_policy_distillation_v2.py b/examples/run_off_policy_distillation_v2.py new file mode 100644 index 0000000000..17bc66561a --- /dev/null +++ b/examples/run_off_policy_distillation_v2.py @@ -0,0 +1,496 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Off-policy distillation v2: single dual-model Policy with GPU-local IPC. + +Uses DTensorDistillationWorker which holds both teacher and student models +in the same Ray actor. Teacher logprobs stay on GPU via IPC buffers — +no Ray object store transfer. 
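
The buffer pattern can be sketched as follows (a CPU stand-in: the real worker shares CUDA IPC handles between Ray actors, which this sketch does not attempt):

```python
import torch

# A log-prob buffer allocated once and reused every step. The teacher fills
# it in place; the student reads a view of it. Only the fill length (plain
# metadata) needs to cross the process boundary.
VOCAB, MAX_TOKENS = 16, 8
buffer = torch.empty(MAX_TOKENS, VOCAB)

def teacher_fill(logits: torch.Tensor, buf: torch.Tensor) -> int:
    n = logits.shape[0]
    buf[:n].copy_(torch.log_softmax(logits, dim=-1))
    return n

n = teacher_fill(torch.randn(5, VOCAB), buffer)
teacher_logprobs = buffer[:n]  # zero-copy view, no serialization
```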
+ +Usage: + uv run examples/run_off_policy_distillation_v2.py \ + --config examples/configs/llama_off_policy_arrow.yaml +""" + +import argparse +import os +import pprint +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Callable, Optional, cast + +import torch +from omegaconf import OmegaConf +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from nemo_rl.algorithms.loss_functions import DistillationLossFn +from nemo_rl.algorithms.off_policy_distillation import ( + OffPolicyDistillationSaveState, + OffPolicyMasterConfig, + _default_distillation_save_state, + check_vocab_equality, +) +from nemo_rl.algorithms.off_policy_distillation_v2 import ( + off_policy_distillation_train_v2, +) +from nemo_rl.algorithms.utils import set_seed +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_eval_dataset +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec +from nemo_rl.data.llm_message_utils import get_keys_from_message_log +from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster, init_ray +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointManager +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import Logger, get_next_experiment_dir, 
print_message_log_samples +from nemo_rl.utils.timer import Timer + +import ray + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True) +OmegaConf.register_new_resolver("max", lambda a, b: max(a, b), replace=True) + +DISTILLATION_WORKER_CLS = "nemo_rl.models.policy.workers.dtensor_distillation_worker.DTensorDistillationWorker" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Off-policy distillation v2 (dual-model worker, GPU-local IPC)" + ) + parser.add_argument("--config", type=str, default=None, help="Path to YAML config") + args, overrides = parser.parse_known_args() + return args, overrides + + +def _sft_preprocessor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False, + datum_preprocessor: Optional[Callable] = None, +) -> DatumSpec: + from nemo_rl.data.llm_message_utils import get_formatted_message_log + + if datum_preprocessor is not None: + datum_dict = datum_preprocessor(datum_dict) + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + tools=datum_dict.get("tools", None), + ) + + length = sum(len(m["token_ids"]) for m in message_log) + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + +def setup_train_data(tokenizer: AutoTokenizer, data_config: DataConfig, seed: int): + from nemo_rl.data.datasets import load_response_dataset + + print("\n Setting up training data...") + data = load_response_dataset(data_config, seed) + 
train_dataset_raw = data.formatted_ds["train"] + task_spec = data.task_spec + + train_dataset = AllTaskProcessedDataset( + train_dataset_raw, + tokenizer, + task_spec, + partial( + _sft_preprocessor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + add_generation_prompt=data_config.get("add_generation_prompt", False), + ), + max_seq_length=data_config["max_input_seq_length"], + ) + print(f" Training dataset loaded with {len(train_dataset)} samples") + return train_dataset, task_spec + + +def setup_eval_data( + tokenizer: AutoTokenizer, + eval_config: dict[str, Any], + max_seq_length: int, +): + print("\n Setting up evaluation benchmarks...") + eval_dataloaders: dict[str, StatefulDataLoader] = {} + eval_envs: dict[str, dict[str, EnvironmentInterface]] = {} + + for bench_name, bench_cfg in eval_config["benchmarks"].items(): + dataset_name = bench_cfg["dataset_name"] + prompt_file = bench_cfg.get("prompt_file") + system_prompt_file = bench_cfg.get("system_prompt_file") + env_cfg = bench_cfg.get("env", {"num_workers": 8}) + + data_cfg = { + "dataset_name": dataset_name, + "prompt_file": prompt_file, + "system_prompt_file": system_prompt_file, + } + base_dataset = load_eval_dataset(data_cfg) + + task_spec = TaskDataSpec( + task_name=dataset_name, + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + + dataset = AllTaskProcessedDataset( + dataset=base_dataset.rekeyed_ds, + tokenizer=tokenizer, + default_task_data_spec=task_spec, + task_data_processors=base_dataset.processor, + max_seq_length=max_seq_length, + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=eval_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + + math_env = MathEnvironment.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), + } + ).remote(env_cfg) + + task_to_env: dict[str, EnvironmentInterface] = 
defaultdict(lambda: math_env)
+        task_to_env[dataset_name] = math_env
+        task_to_env[None] = math_env
+
+        eval_dataloaders[bench_name] = dataloader
+        eval_envs[bench_name] = task_to_env
+        print(f" ✓ {bench_name}: {len(dataset)} samples, env={dataset_name}")
+
+    return eval_dataloaders, eval_envs
+
+
+def gen_validate(
+    generation: GenerationInterface,
+    eval_dataloaders: dict[str, StatefulDataLoader],
+    eval_envs: dict[str, dict[str, EnvironmentInterface]],
+    eval_config: dict[str, Any],
+    master_config: dict[str, Any],
+    step: int,
+    tokenizer: PreTrainedTokenizerBase | None = None,
+) -> tuple[dict[str, Any], dict[str, Any]]:
+    timer = Timer()
+    all_val_metrics: dict[str, Any] = {}
+
+    max_val_samples = eval_config.get("max_val_samples", 512)
+    val_batch_size = eval_config["val_batch_size"]
+    max_batches = max_val_samples // val_batch_size
+    max_rollout_turns = eval_config.get("max_rollout_turns", 1)
+    max_seq_len = master_config["policy"]["max_total_sequence_length"]
+
+    with timer.time("total_eval_time"):
+        for bench_name, dataloader in eval_dataloaders.items():
+            print(f"\n▶ Evaluating {bench_name} at step {step}...", flush=True)
+            total_rewards = []
+            total_lengths = []
+            all_message_logs = []
+
+            for batch_idx, val_batch in enumerate(dataloader):
+                if batch_idx >= max_batches:
+                    break
+
+                val_batch, gen_metrics = run_multi_turn_rollout(
+                    generation, val_batch, tokenizer,
+                    eval_envs[bench_name],
+                    max_seq_len=max_seq_len,
+                    max_rollout_turns=max_rollout_turns,
+                    greedy=True,
+                )
+
+                rewards = val_batch["total_reward"]
+                total_rewards.extend(rewards.tolist())
+                total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
+
+                to_env = [
+                    get_keys_from_message_log(val_batch["message_log"][i], ["role", "content"])
+                    for i in range(len(val_batch["message_log"]))
+                ]
+                all_message_logs.extend(to_env)
+
+            accuracy = sum(total_rewards) / len(total_rewards) if total_rewards else 0
+            avg_length = sum(total_lengths) / len(total_lengths) if total_lengths else 0
+
+            all_val_metrics[f"{bench_name}_accuracy"] = accuracy
+            all_val_metrics[f"{bench_name}_avg_length"] = avg_length
+
+            print(f"\n📊 {bench_name} Results:")
+            print(f"  • Accuracy: {accuracy:.4f}")
+            print(f"  • Avg response length: {avg_length:.1f} tokens")
+            print(f"  • Samples processed: {len(total_rewards)}", flush=True)
+
+            try:
+                num_to_print = master_config["logger"].get("num_val_samples_to_print", 3)
+                print_message_log_samples(
+                    all_message_logs, total_rewards,
+                    num_samples=min(num_to_print, len(all_message_logs)),
+                    step=step,
+                )
+            except Exception as e:
+                print(f"  ⚠️ Error displaying samples: {e}", flush=True)
+
+    timing_metrics = timer.get_timing_metrics(reduction_op="sum")
+    eval_time = timing_metrics.get("total_eval_time", 0)
+    print(f"\n  ⏱️ Total eval time: {eval_time:.2f}s", flush=True)
+    timer.reset()
+
+    return all_val_metrics, timing_metrics
+
+
+def make_gen_eval_hook(generation, eval_dataloaders, eval_envs,
+                       eval_config, master_config, tokenizer, colocated_inference):
+    generation_stale = True
+
+    def hook(step, student_policy, teacher_policy, logger):
+        nonlocal generation_stale
+        from nemo_rl.algorithms.grpo import refit_policy_generation
+
+        if generation_stale:
+            refit_policy_generation(student_policy, generation, colocated_inference)
+            generation_stale = False
+
+        val_metrics, val_timings = gen_validate(
+            generation, eval_dataloaders, eval_envs,
+            eval_config, master_config, step=step, tokenizer=tokenizer,
+        )
+        generation.finish_generation()
+        logger.log_metrics(val_timings, step, prefix="timing/validation")
+        logger.log_metrics(val_metrics, step, prefix="validation")
+        generation_stale = True
+        return val_metrics
+
+    return hook
+
+
+def main():
+    args, overrides = parse_args()
+
+    if not args.config:
+        args.config = os.path.join(
+            os.path.dirname(__file__), "configs", "llama_off_policy_arrow.yaml"
+        )
+
+    config = load_config(args.config)
+    print(f"Loaded configuration from: {args.config}")
+
+    if overrides:
+        print(f"Overrides: {overrides}")
+        config = parse_hydra_overrides(config, overrides)
+
+    config: OffPolicyMasterConfig = OmegaConf.to_container(config, resolve=True)
+    print("Final config:")
+    pprint.pprint(config)
+
+    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
+    print(f"📊 Using log directory: {config['logger']['log_dir']}")
+
+    init_ray()
+
+    from nemo_rl.algorithms.utils import get_tokenizer
+    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
+
+    generation_config = config["policy"].get("generation")
+    if generation_config is not None:
+        config["policy"]["generation"] = configure_generation_config(
+            generation_config, tokenizer
+        )
+
+    train_dataset, task_spec = setup_train_data(
+        tokenizer, config["data"], config["distillation"]["seed"]
+    )
+
+    set_seed(config["distillation"]["seed"])
+
+    policy_config = config["policy"]
+    teacher_config = config["teacher"]
+    distillation_config = config["distillation"]
+    data_config = config["data"]
+    cluster_config = config["cluster"]
+
+    logger = Logger(config["logger"])
+    logger.log_hyperparams(config)
+
+    checkpointer = CheckpointManager(config["checkpointing"])
+    last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
+    distillation_save_state: Optional[OffPolicyDistillationSaveState] = cast(
+        Optional[OffPolicyDistillationSaveState],
+        checkpointer.load_training_info(last_checkpoint_path),
+    )
+    if distillation_save_state is None:
+        distillation_save_state = _default_distillation_save_state()
+
+    train_dataloader = StatefulDataLoader(
+        train_dataset,
+        batch_size=distillation_config["num_prompts_per_step"],
+        shuffle=data_config.get("shuffle", True),
+        collate_fn=rl_collate_fn,
+        drop_last=True,
+    )
+    if last_checkpoint_path:
+        train_dataloader.load_state_dict(
+            torch.load(os.path.join(last_checkpoint_path, "train_dataloader.pt"))
+        )
+
+    has_generation = generation_config is not None
+    max_colocated = 3 if has_generation else 2
+
+    print("\n▶ Setting up compute cluster...")
+    cluster = RayVirtualCluster(
+        name="off_policy_distillation_v2_cluster",
+        bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
+        * cluster_config["num_nodes"],
+        use_gpus=True,
+        num_gpus_per_node=cluster_config["gpus_per_node"],
+        max_colocated_worker_groups=max_colocated,
+    )
+    print(f" ✓ Cluster: {cluster_config['num_nodes']} nodes, max_colocated={max_colocated}")
+
+    # Runs unless NRL_SKIP_DISTILLATION_TOKENIZER_CHECK is set to a non-empty value.
+    if not os.getenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK"):
+        check_vocab_equality(
+            tokenizer, policy_config["model_name"], teacher_config["model_name"]
+        )
+
+    # Single Policy with both student + teacher models in each worker
+    print("\n▶ Setting up dual-model policy (student + teacher)...")
+    weights_path = None
+    optimizer_path = None
+    if last_checkpoint_path:
+        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
+        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
+
+    policy = Policy(
+        name_prefix="distillation",
+        cluster=cluster,
+        config=policy_config,
+        tokenizer=tokenizer,
+        weights_path=weights_path,
+        optimizer_path=optimizer_path,
+        init_optimizer=True,
+        init_reference_model=False,
+        worker_builder_cls_override=DISTILLATION_WORKER_CLS,
+        extra_worker_kwargs={"teacher_config": teacher_config},
+    )
+
+    loss_fn = DistillationLossFn(config["loss_fn"])
+
+    # vLLM Generation (colocated, for eval only)
+    generation: Optional[GenerationInterface] = None
+    if has_generation:
+        print("\n▶ Setting up vLLM generation (colocated, for eval)...")
+        gen_cfg = config["policy"]["generation"]
+        gen_cfg["model_name"] = policy_config["model_name"]
+        if "vllm_cfg" in gen_cfg:
+            gen_cfg["vllm_cfg"]["hf_overrides"] = policy_config.get(
+                "hf_config_overrides", {}
+            )
+
+        generation = VllmGeneration(
+            cluster=cluster, config=cast(VllmConfig, gen_cfg)
+        )
+        generation.finish_generation()
+
+        state_dict_info = policy.prepare_refit_info()
+        generation.prepare_refit_info(state_dict_info)
+        print(f" ✓ vLLM generation ready (model={policy_config['model_name']})")
+
+    eval_dataloaders = None
+    
eval_envs = None + eval_config = config.get("eval") + if eval_config and has_generation: + eval_dataloaders, eval_envs = setup_eval_data( + tokenizer, eval_config, + max_seq_length=policy_config["max_total_sequence_length"], + ) + + print("\n" + "=" * 60) + print(" " * 10 + "OFF-POLICY DISTILLATION V2 SETUP COMPLETE") + print("=" * 60 + "\n") + + # Build eval hook + eval_hook = None + eval_hook_period = 0 + eval_hook_at_start = False + if has_generation and eval_config and eval_dataloaders and eval_envs: + colocated_inference = ( + config["policy"]["generation"]["colocated"]["enabled"] + if config["policy"].get("generation") + else True + ) + eval_hook = make_gen_eval_hook( + generation, eval_dataloaders, eval_envs, + eval_config, config, tokenizer, colocated_inference, + ) + eval_hook_period = eval_config["val_period"] + eval_hook_at_start = eval_config.get("val_at_start", False) + + # Train + off_policy_distillation_train_v2( + policy=policy, + dataloader=train_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + logger=logger, + checkpointer=checkpointer, + distillation_save_state=distillation_save_state, + master_config=config, + eval_hook=eval_hook, + eval_hook_period=eval_hook_period, + eval_hook_at_start=eval_hook_at_start, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_sft_arrow_with_eval.py b/examples/run_sft_arrow_with_eval.py new file mode 100644 index 0000000000..724e0ad908 --- /dev/null +++ b/examples/run_sft_arrow_with_eval.py @@ -0,0 +1,822 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SFT training on arrow data with inline generation-based MATH/MMLU evaluation. + +This script extends the standard SFT flow with periodic generation-based +evaluation using vLLM (colocated). It does NOT modify sft.py or run_sft.py; +instead it builds its own setup / training loop on top of the existing SFT +primitives. + +Usage: + uv run examples/run_sft_arrow_with_eval.py \ + --config examples/configs/llama_sft_arrow.yaml +""" + +import argparse +import os +import pprint +import warnings +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Callable, Optional, cast + +import numpy as np +import torch +from omegaconf import OmegaConf +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import refit_policy_generation +from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.sft import SFTSaveState, _default_sft_save_state +from nemo_rl.algorithms.utils import set_seed +from nemo_rl.data import DataConfig +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_eval_dataset +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec +from nemo_rl.data.llm_message_utils import ( + add_loss_mask_to_message_log, + batched_message_log_to_flat_message, + get_keys_from_message_log, +) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.ray_actor_environment_registry 
import get_actor_python_env +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster, init_ray +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.environments.math_environment import MathEnvironment +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.models.generation.interfaces import GenerationInterface +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointManager +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import Logger, get_next_experiment_dir, print_message_log_samples +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b, replace=True) +OmegaConf.register_new_resolver("max", lambda a, b: max(a, b), replace=True) + + +# ========================================================================= +# CLI +# ========================================================================= +def parse_args(): + parser = argparse.ArgumentParser( + description="SFT on arrow data with inline MATH/MMLU eval" + ) + parser.add_argument("--config", type=str, default=None, help="Path to YAML config") + args, overrides = parser.parse_known_args() + return args, overrides + + +# ========================================================================= +# Arrow-text SFT data (reused from run_sft.py pattern) +# ========================================================================= +def _sft_preprocessor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + add_bos: bool = True, + add_eos: bool = True, + add_generation_prompt: bool = False, +) -> DatumSpec: + from 
nemo_rl.data.llm_message_utils import get_formatted_message_log + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + ) + + length = sum(len(m["token_ids"]) for m in message_log) + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + + +def setup_train_data( + tokenizer: AutoTokenizer, data_config: DataConfig, seed: int +): + from nemo_rl.data.datasets import load_response_dataset + + print("\n▶ Setting up training data...") + data = load_response_dataset(data_config, seed) + train_dataset_raw = data.formatted_ds["train"] + sft_task_spec = data.task_spec + + train_dataset = AllTaskProcessedDataset( + train_dataset_raw, + tokenizer, + sft_task_spec, + partial( + _sft_preprocessor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + add_generation_prompt=data_config.get("add_generation_prompt", False), + ), + max_seq_length=data_config["max_input_seq_length"], + ) + print(f" ✓ Training dataset loaded with {len(train_dataset)} samples") + return train_dataset, sft_task_spec + + +# ========================================================================= +# Eval data + environments (MATH, MMLU, etc.) +# ========================================================================= +def setup_eval_data( + tokenizer: AutoTokenizer, + eval_config: dict[str, Any], + max_seq_length: int, +) -> tuple[ + dict[str, StatefulDataLoader], + dict[str, dict[str, EnvironmentInterface]], +]: + """Load eval benchmark datasets and create scoring environments. 
+ + Returns: + eval_dataloaders: {benchmark_name: StatefulDataLoader} + eval_envs: {benchmark_name: {task_name: EnvironmentInterface}} + """ + print("\n▶ Setting up evaluation benchmarks...") + eval_dataloaders: dict[str, StatefulDataLoader] = {} + eval_envs: dict[str, dict[str, EnvironmentInterface]] = {} + + for bench_name, bench_cfg in eval_config["benchmarks"].items(): + dataset_name = bench_cfg["dataset_name"] + prompt_file = bench_cfg.get("prompt_file") + system_prompt_file = bench_cfg.get("system_prompt_file") + env_cfg = bench_cfg.get("env", {"num_workers": 8}) + + data_cfg = { + "dataset_name": dataset_name, + "prompt_file": prompt_file, + "system_prompt_file": system_prompt_file, + } + base_dataset = load_eval_dataset(data_cfg) + + task_spec = TaskDataSpec( + task_name=dataset_name, + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + + dataset = AllTaskProcessedDataset( + dataset=base_dataset.rekeyed_ds, + tokenizer=tokenizer, + default_task_data_spec=task_spec, + task_data_processors=base_dataset.processor, + max_seq_length=max_seq_length, + ) + + dataloader = StatefulDataLoader( + dataset, + batch_size=eval_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + + math_env = MathEnvironment.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), + } + ).remote(env_cfg) + + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env) + task_to_env[dataset_name] = math_env + task_to_env[None] = math_env + + eval_dataloaders[bench_name] = dataloader + eval_envs[bench_name] = task_to_env + print(f" ✓ {bench_name}: {len(dataset)} samples, env={dataset_name}") + + return eval_dataloaders, eval_envs + + +# ========================================================================= +# Generation-based validation (ported from distillation.py) +# 
========================================================================= +def gen_validate( + generation: GenerationInterface, + eval_dataloaders: dict[str, StatefulDataLoader], + eval_envs: dict[str, dict[str, EnvironmentInterface]], + eval_config: dict[str, Any], + master_config: dict[str, Any], + step: int, + tokenizer: PreTrainedTokenizerBase | None = None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run generation-based evaluation on all configured benchmarks.""" + timer = Timer() + all_val_metrics: dict[str, Any] = {} + + max_val_samples = eval_config.get("max_val_samples", 512) + val_batch_size = eval_config["val_batch_size"] + max_batches = max_val_samples // val_batch_size + max_rollout_turns = eval_config.get("max_rollout_turns", 1) + max_seq_len = master_config["policy"]["max_total_sequence_length"] + + with timer.time("total_eval_time"): + for bench_name, dataloader in eval_dataloaders.items(): + print(f"\n▶ Evaluating {bench_name} at step {step}...", flush=True) + total_rewards = [] + total_lengths = [] + all_message_logs = [] + + for batch_idx, val_batch in enumerate(dataloader): + if batch_idx >= max_batches: + break + + val_batch, gen_metrics = run_multi_turn_rollout( + generation, + val_batch, + tokenizer, + eval_envs[bench_name], + max_seq_len=max_seq_len, + max_rollout_turns=max_rollout_turns, + greedy=True, + ) + + rewards = val_batch["total_reward"] + total_rewards.extend(rewards.tolist()) + total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"]) + + to_env = [ + get_keys_from_message_log( + val_batch["message_log"][i], ["role", "content"] + ) + for i in range(len(val_batch["message_log"])) + ] + all_message_logs.extend(to_env) + + accuracy = ( + sum(total_rewards) / len(total_rewards) + if len(total_rewards) > 0 + else 0 + ) + avg_length = ( + sum(total_lengths) / len(total_lengths) + if len(total_lengths) > 0 + else 0 + ) + + all_val_metrics[f"{bench_name}_accuracy"] = accuracy + all_val_metrics[f"{bench_name}_avg_length"] = 
avg_length + + print(f"\n📊 {bench_name} Results:") + print(f" • Accuracy: {accuracy:.4f}") + print(f" • Avg response length: {avg_length:.1f} tokens") + print(f" • Samples processed: {len(total_rewards)}", flush=True) + + try: + num_to_print = master_config["logger"].get( + "num_val_samples_to_print", 3 + ) + print_message_log_samples( + all_message_logs, + total_rewards, + num_samples=min(num_to_print, len(all_message_logs)), + step=step, + ) + except Exception as e: + print(f" ⚠️ Error displaying samples: {e}", flush=True) + + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + eval_time = timing_metrics.get("total_eval_time", 0) + print(f"\n ⏱️ Total eval time: {eval_time:.2f}s", flush=True) + timer.reset() + + return all_val_metrics, timing_metrics + + +# ========================================================================= +# Training loop (based on sft_train, with eval hooks) +# ========================================================================= +def sft_train_with_eval( + policy: Policy, + train_dataloader: StatefulDataLoader, + tokenizer: AutoTokenizer, + loss_fn: NLLLoss, + master_config: dict[str, Any], + logger: Logger, + sft_task_spec: TaskDataSpec, + checkpointer: CheckpointManager, + sft_save_state: Optional[SFTSaveState], + generation: Optional[GenerationInterface], + eval_dataloaders: Optional[dict[str, StatefulDataLoader]], + eval_envs: Optional[dict[str, dict[str, EnvironmentInterface]]], +) -> None: + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"]["checkpoint_must_save_by"], + fit_last_save_time=True, + ) + timeout.start_iterations() + + if sft_save_state is None: + sft_save_state = _default_sft_save_state() + current_epoch = 0 + current_step = 0 + total_steps = 0 + total_valid_tokens = 0 + else: + current_epoch = sft_save_state["epoch"] + current_step = sft_save_state["step"] + total_steps = sft_save_state["total_steps"] + total_valid_tokens = sft_save_state.get("total_valid_tokens", 0) 
+
+    sft_config = master_config["sft"]
+    max_num_epochs = sft_config["max_num_epochs"]
+    max_num_steps = sft_config["max_num_steps"]
+
+    eval_config = master_config.get("eval")
+    eval_period = eval_config.get("val_period", 0) if eval_config else 0
+    eval_at_start = eval_config.get("val_at_start", False) if eval_config else False
+    has_eval = (
+        eval_period > 0
+        and generation is not None
+        and eval_dataloaders is not None
+        and eval_envs is not None
+    )
+
+    colocated_inference = (
+        master_config["policy"]["generation"]["colocated"]["enabled"]
+        if master_config["policy"].get("generation")
+        else True
+    )
+    generation_stale = True
+
+    # ── optional eval at start ──
+    if has_eval and eval_at_start and total_steps == 0:
+        print("\n🔍 Running initial evaluation...", flush=True)
+        refit_policy_generation(policy, generation, colocated_inference)
+        generation_stale = False
+        val_metrics, val_timings = gen_validate(
+            generation, eval_dataloaders, eval_envs, eval_config, master_config, step=0,
+            tokenizer=tokenizer,
+        )
+        generation.finish_generation()
+        logger.log_metrics(val_metrics, 0, prefix="validation")
+        logger.log_metrics(val_timings, 0, prefix="timing/validation")
+
+    policy.prepare_for_training()
+
+    while current_epoch < max_num_epochs and total_steps < max_num_steps:
+        print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")
+
+        for batch in train_dataloader:
+            print(
+                f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), max_num_steps)} {'=' * 25}"
+            )
+            maybe_gpu_profile_step(policy, total_steps + 1)
+            val_metrics = None
+
+            with timer.time("total_step_time"):
+                # ── data prep ──
+                print("▶ Preparing batch...")
+                with timer.time("data_processing"):
+                    add_loss_mask_to_message_log(
+                        batch["message_log"], roles_to_train_on=["assistant"]
+                    )
+                    cat_and_padded, input_lengths = batched_message_log_to_flat_message(
+                        batch["message_log"],
+                        pad_value_dict={"token_ids": tokenizer.pad_token_id},
+                        
make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } + ) + train_data.update( + cat_and_padded.get_multimodal_dict(as_tensors=False) + ) + + # ── NaN debug: check input data before training ── + if current_step < 3: + _ids = train_data["input_ids"] + _mask = train_data["token_mask"] + _smask = train_data["sample_mask"] + print( + f" [NaN debug] input_ids shape={_ids.shape}, " + f"min={_ids.min().item()}, max={_ids.max().item()}, " + f"token_mask sum={_mask.sum().item():.0f}, " + f"sample_mask sum={_smask.sum().item():.0f}/{_smask.numel()}, " + f"input_lengths range=[{train_data['input_lengths'].min().item()}, " + f"{train_data['input_lengths'].max().item()}]", + flush=True, + ) + + # ── train step ── + print("▶ Taking a training step...") + with timer.time("policy_training"): + train_results = policy.train(train_data, loss_fn) + + # ── NaN debug: check loss ── + if current_step < 3: + import numpy as _np + _loss_val = train_results["loss"].numpy() + if _np.isnan(_loss_val).any(): + print( + f" [NaN debug] Loss is NaN at step {total_steps + 1}! 
" + f"grad_norm={train_results['grad_norm'].numpy()}, " + f"all_mb_metrics={train_results.get('all_mb_metrics', {})}", + flush=True, + ) + + generation_stale = True + + is_last_step = total_steps + 1 >= max_num_steps or ( + current_epoch + 1 == max_num_epochs + and current_step + 1 == len(train_dataloader) + ) + + # ── generation-based eval ── + if has_eval and eval_period > 0 and (total_steps + 1) % eval_period == 0: + print( + f"\n🔍 Running generation-based eval at step {total_steps + 1}...", + flush=True, + ) + with timer.time("gen_eval"): + if generation_stale: + refit_policy_generation( + policy, generation, colocated_inference + ) + generation_stale = False + val_metrics, val_timings = gen_validate( + generation, + eval_dataloaders, + eval_envs, + eval_config, + master_config, + step=total_steps + 1, + tokenizer=tokenizer, + ) + generation.finish_generation() + logger.log_metrics( + val_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + policy.prepare_for_training() + + # ── metrics ── + metrics = { + "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + } + if "moe_metrics" in train_results: + metrics.update( + {f"moe/{k}": v for k, v in train_results["moe_metrics"].items()} + ) + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in {"lr", "wd", "global_valid_seqs", "global_valid_toks"}: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + total_valid_tokens += metrics["global_valid_toks"] + + # ── checkpointing ── + sft_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + timeout.mark_iteration() + should_save_by_step = ( + is_last_step + or (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 + ) + should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + 
should_save_by_step or should_save_by_timeout + ): + sft_save_state["step"] = (current_step + 1) % len(train_dataloader) + sft_save_state["total_steps"] = total_steps + 1 + sft_save_state["epoch"] = current_epoch + sft_save_state["total_valid_tokens"] = total_valid_tokens + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"Checkpointing metric {metric_name} requested but no {prefix} metrics collected.", + stacklevel=2, + ) + if full_metric_name in sft_save_state: + del sft_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + sft_save_state[full_metric_name] = metrics_source[ + metric_name + ] + + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {total_steps + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, sft_save_state, master_config + ) + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # ── logging ── + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\n📊 Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + if "total_flops" in train_results: + total_tflops = ( + 
train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + ) + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops + + print("\n⏱️ Timing:") + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + + if should_save_by_timeout: + print("Timeout reached, stopping training early", flush=True) + return + if total_steps >= max_num_steps: + print("Max steps reached, stopping training early", flush=True) + return + + current_epoch += 1 + current_step = 0 + + +# ========================================================================= +# Main +# ========================================================================= +def main(): + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "llama_sft_arrow.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if 
overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config = OmegaConf.to_container(config, resolve=True) + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"📊 Using log directory: {config['logger']['log_dir']}") + + init_ray() + + # ── Tokenizer ── + from nemo_rl.algorithms.utils import get_tokenizer + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + + # ── Configure generation ── + generation_config = config["policy"].get("generation") + if generation_config is not None: + config["policy"]["generation"] = configure_generation_config( + generation_config, tokenizer + ) + + # ── Training data (arrow) ── + train_dataset, sft_task_spec = setup_train_data( + tokenizer, config["data"], config["sft"]["seed"] + ) + + # ── Core SFT setup (cluster, policy, dataloader, loss, logger, checkpointer) ── + set_seed(config["sft"]["seed"]) + + policy_config = config["policy"] + data_config = config["data"] + logger_config = config["logger"] + cluster_config = config["cluster"] + sft_config = config["sft"] + + logger = Logger(logger_config) + logger.log_hyperparams(config) + + checkpointer = CheckpointManager(config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + sft_save_state: Optional[SFTSaveState] = cast( + Optional[SFTSaveState], checkpointer.load_training_info(last_checkpoint_path) + ) + + train_dataloader = StatefulDataLoader( + train_dataset, + batch_size=policy_config["train_global_batch_size"], + shuffle=data_config.get("shuffle", True), + collate_fn=rl_collate_fn, + drop_last=True, + num_workers=data_config.get("num_workers", 1), + ) + if last_checkpoint_path is not None: + train_dataloader.load_state_dict( + torch.load(os.path.join(last_checkpoint_path, "train_dataloader.pt")) + ) + + has_generation = generation_config is not None + max_colocated = 3 if has_generation else 1 + + 
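An aside on the resume bookkeeping used by the training loop earlier in this diff: the saved `step` is stored modulo the epoch length, so a checkpoint taken exactly at an epoch boundary resumes at step 0 of the next epoch. A minimal, pure-Python sketch (the helper name is hypothetical; only the dict keys and arithmetic mirror `sft_save_state`):

```python
# Hypothetical helper sketching the sft_save_state arithmetic; not a NeMo RL API.
def save_state_after(total_steps: int, current_step: int,
                     current_epoch: int, steps_per_epoch: int) -> dict:
    return {
        # step is stored modulo the epoch length so an epoch-boundary
        # checkpoint restarts the dataloader at step 0 on resume
        "step": (current_step + 1) % steps_per_epoch,
        "total_steps": total_steps + 1,
        "epoch": current_epoch,
    }

# Mid-epoch checkpoint: resume at step 3 of the same epoch.
print(save_state_after(12, 2, 1, 10))
# Epoch-boundary checkpoint: step wraps to 0.
print(save_state_after(19, 9, 1, 10))
```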
print("\n▶ Setting up compute cluster...") + cluster = RayVirtualCluster( + name="sft_eval_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=max_colocated, + ) + print(f" ✓ Cluster: {cluster_config['num_nodes']} nodes, max_colocated={max_colocated}") + + # ── Policy ── + print("\n▶ Setting up model...") + if policy_config.get("megatron_cfg", {}).get("enabled", False): + total_train_iters = min( + sft_config["max_num_steps"], + sft_config["max_num_epochs"] * len(train_dataloader), + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + processor = None + if not isinstance(tokenizer, PreTrainedTokenizerBase): + processor = tokenizer + tokenizer = processor.tokenizer + + policy = Policy( + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + processor=processor, + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" + if last_checkpoint_path + else None, + init_optimizer=True, + init_reference_model=False, + ) + policy.print_node_ip_and_gpu_id() + loss_fn = NLLLoss() + print(" ✓ Model initialized") + + # ── vLLM Generation (colocated) ── + generation: Optional[GenerationInterface] = None + if has_generation: + print("\n▶ Setting up vLLM generation (colocated)...") + gen_cfg = config["policy"]["generation"] + gen_cfg["model_name"] = policy_config["model_name"] + if "vllm_cfg" in gen_cfg: + gen_cfg["vllm_cfg"]["hf_overrides"] = policy_config.get( + "hf_config_overrides", {} + ) + + generation = VllmGeneration( + cluster=cluster, config=cast(VllmConfig, gen_cfg) + ) + generation.finish_generation() + + state_dict_info = policy.prepare_refit_info() + generation.prepare_refit_info(state_dict_info) + print(f" ✓ vLLM generation ready (model={policy_config['model_name']})") + + 
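The `finish_generation()` / `prepare_refit_info` handshake above supports lazy weight refits: the train loop only marks generation weights stale after an optimizer step, and `refit_policy_generation` pushes weights to the vLLM engine just before generation is actually needed. A toy, self-contained sketch of that flag pattern (class names are illustrative stand-ins, not NeMo RL APIs):

```python
# Toy model of the lazy-refit pattern; illustrative names, not NeMo RL APIs.
class ToyEngine:
    """Stands in for the colocated vLLM generation worker."""
    def __init__(self) -> None:
        self.weights_version = 0

class ToyLoop:
    """Stands in for the train loop's generation_stale bookkeeping."""
    def __init__(self, engine: ToyEngine) -> None:
        self.engine = engine
        self.train_version = 0
        self.generation_stale = False

    def train_step(self) -> None:
        self.train_version += 1       # optimizer step changed the weights
        self.generation_stale = True  # but do NOT refit yet

    def generate(self) -> int:
        if self.generation_stale:     # refit lazily, only when needed
            self.engine.weights_version = self.train_version
            self.generation_stale = False
        return self.engine.weights_version

loop = ToyLoop(ToyEngine())
loop.train_step()
loop.train_step()
print(loop.generate())  # a single refit covers both train steps
```

Deferring the refit this way means N consecutive train steps between evals cost one weight transfer, not N.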
# ── Eval datasets + environments ── + eval_dataloaders: Optional[dict[str, StatefulDataLoader]] = None + eval_envs: Optional[dict[str, dict[str, EnvironmentInterface]]] = None + + eval_config = config.get("eval") + if eval_config and has_generation: + eval_dataloaders, eval_envs = setup_eval_data( + tokenizer, + eval_config, + max_seq_length=policy_config["max_total_sequence_length"], + ) + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + + # ── Train ── + sft_train_with_eval( + policy=policy, + train_dataloader=train_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + master_config=config, + logger=logger, + sft_task_spec=sft_task_spec, + checkpointer=checkpointer, + sft_save_state=sft_save_state, + generation=generation, + eval_dataloaders=eval_dataloaders, + eval_envs=eval_envs, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/test_dataloader.py b/examples/test_dataloader.py new file mode 100644 index 0000000000..5e3f498c39 --- /dev/null +++ b/examples/test_dataloader.py @@ -0,0 +1,230 @@ +""" +Test the exact dataloader pipeline used by run_off_policy_distillation_arrow.py. 
+ +Exercises the REAL code path without Ray, GPU, or model loading: + load_response_dataset -> AllTaskProcessedDataset -> StatefulDataLoader(rl_collate_fn) + +Usage (from the RL/ directory): + python examples/test_dataloader.py # all arrow files, batch=4 + python examples/test_dataloader.py --max-files 1 # only 1 arrow file (fast) + python examples/test_dataloader.py --batch-size 8 --num-batches 5 +""" + +import argparse +import glob +import time +from functools import partial + +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + +from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec +from nemo_rl.data.llm_message_utils import get_formatted_message_log + + +# Inlined from run_off_policy_distillation_arrow.py to avoid import issues +def sft_preprocessor(datum_dict, task_data_spec, tokenizer, max_seq_length, idx, + add_bos=True, add_eos=True, add_generation_prompt=False, + datum_preprocessor=None): + """Process a datum dictionary for off-policy distillation.""" + if datum_preprocessor is not None: + datum_dict = datum_preprocessor(datum_dict) + + message_log = get_formatted_message_log( + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + add_eos_token=add_eos, + add_generation_prompt=add_generation_prompt, + tools=datum_dict.get("tools", None), + ) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + return { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + +ARROW_GLOB = 
"/lustre/fsw/portfolios/llmservice/users/sdiao/data/climb_nm5.5_phase3_400b_shuffled_text_only_global_shuffle/*.arrow" +MODEL_NAME = "Qwen/Qwen3-1.7B-Base" # same as config + + +def parse_args(): + p = argparse.ArgumentParser(description="Test arrow dataloader pipeline") + p.add_argument("--arrow-glob", type=str, default=ARROW_GLOB, + help="Glob pattern for arrow files") + p.add_argument("--max-files", type=int, default=None, + help="Limit number of arrow files to load (None = all)") + p.add_argument("--model", type=str, default=MODEL_NAME, + help="Tokenizer model name") + p.add_argument("--max-seq-length", type=int, default=8192, + help="Max sequence length (matches config)") + p.add_argument("--batch-size", type=int, default=4, + help="Batch size for dataloader test") + p.add_argument("--num-batches", type=int, default=3, + help="Number of batches to iterate") + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + + # ============================================================ + # 1. Discover arrow files + # ============================================================ + all_files = sorted(glob.glob(args.arrow_glob)) + if not all_files: + raise FileNotFoundError(f"No arrow files found at: {args.arrow_glob}") + if args.max_files: + all_files = all_files[: args.max_files] + print(f"[1/5] Found {len(all_files)} arrow file(s)") + + # ============================================================ + # 2. Load tokenizer (CPU only, no model weights) + # ============================================================ + print(f"\n[2/5] Loading tokenizer: {args.model}") + tokenizer = AutoTokenizer.from_pretrained(args.model) + print(f" vocab_size={tokenizer.vocab_size}, " + f"bos={tokenizer.bos_token_id}, eos={tokenizer.eos_token_id}") + + # ============================================================ + # 3. 
Load dataset via the REAL pipeline (load_response_dataset) + # ============================================================ + # Build the same data_config that off_policy_distillation.yaml produces + data_config = { + "dataset_name": "arrow_text", + "arrow_files": all_files, # pass resolved list instead of glob + "val_split": 0.05, + "text_key": "text", + "max_input_seq_length": args.max_seq_length, + "prompt_file": None, + "system_prompt_file": None, + "shuffle": True, + "add_bos": True, + "add_eos": True, + "add_generation_prompt": False, + } + + print(f"\n[3/5] load_response_dataset (dataset_name='arrow_text')...") + t0 = time.time() + data = load_response_dataset(data_config, seed=args.seed) + train_raw = data.formatted_ds["train"] + val_raw = data.formatted_ds["validation"] + task_spec = data.task_spec + elapsed = time.time() - t0 + print(f" Train: {len(train_raw):,} samples") + print(f" Val: {len(val_raw):,} samples") + print(f" Loaded in {elapsed:.1f}s") + + # Quick sanity check on raw data + sample0 = train_raw[0] + print(f"\n Raw sample 0 keys: {list(sample0.keys())}") + print(f" messages[0]['role']: {sample0['messages'][0]['role']}") + print(f" messages[0]['content']: {sample0['messages'][0]['content'][:200]}...") + + # ============================================================ + # 4. 
Wrap with AllTaskProcessedDataset (tokenization + truncation) + # ============================================================ + print(f"\n[4/5] AllTaskProcessedDataset + sft_preprocessor (max_seq_length={args.max_seq_length})...") + train_dataset = AllTaskProcessedDataset( + train_raw, + tokenizer, + task_spec, + partial( + sft_preprocessor, + add_bos=data_config["add_bos"], + add_eos=data_config["add_eos"], + add_generation_prompt=data_config["add_generation_prompt"], + ), + max_seq_length=args.max_seq_length, + ) + print(f" Dataset length: {len(train_dataset):,}") + + # Test individual items + print("\n --- Individual sample checks ---") + n_truncated = 0 + for i in range(min(5, len(train_dataset))): + item = train_dataset[i] + roles = [m["role"] for m in item["message_log"]] + n_tokens = sum(len(m["token_ids"]) for m in item["message_log"]) + if item["loss_multiplier"] == 0.0: + n_truncated += 1 + print(f" [{i}] length={item['length']:>6}, tokens_after_trunc={n_tokens:>6}, " + f"loss_mult={item['loss_multiplier']:.1f}, roles={roles}") + if n_truncated: + print(f" WARNING: {n_truncated}/5 samples were truncated (loss_multiplier=0.0). " + f"Consider increasing --max-seq-length (currently {args.max_seq_length}).") + + # ============================================================ + # 5. 
StatefulDataLoader + rl_collate_fn (the real dataloader) + # ============================================================ + print(f"\n[5/5] StatefulDataLoader (batch_size={args.batch_size}, " + f"collate_fn=rl_collate_fn, drop_last=True)") + dataloader = StatefulDataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + collate_fn=rl_collate_fn, + drop_last=True, + ) + print(f" Total batches: {len(dataloader):,}") + + print(f"\n --- Iterating {args.num_batches} batch(es) ---") + for bi, batch in enumerate(dataloader): + if bi >= args.num_batches: + break + + lengths = batch["length"].tolist() + loss_mults = batch["loss_multiplier"].tolist() + max_len = batch["batch_max_length"][0].item() + + print(f"\n Batch {bi}:") + print(f" keys: {sorted(batch.keys())}") + print(f" num_samples: {len(batch['message_log'])}") + print(f" lengths: {lengths}") + print(f" loss_multipliers: {loss_mults}") + print(f" batch_max_length: {max_len}") + + # Spot-check first sample + msg_log_0 = batch["message_log"][0] + tok0 = msg_log_0[0]["token_ids"] + print(f" sample[0] roles: {[m['role'] for m in msg_log_0]}") + print(f" sample[0] tok[:10]: {tok0[:10].tolist()}") + decoded = tokenizer.decode(tok0[:30], skip_special_tokens=False) + print(f" sample[0] decoded[:100]: {decoded[:100]}") + + # ============================================================ + # Summary + # ============================================================ + print("\n" + "=" * 60) + print("DATALOADER TEST COMPLETE") + print("=" * 60) + print(f" Arrow files: {len(all_files)}") + print(f" Train samples: {len(train_raw):,}") + print(f" Val samples: {len(val_raw):,}") + print(f" Tokenizer: {args.model}") + print(f" Max seq length: {args.max_seq_length}") + print(f" Batch size: {args.batch_size}") + print(f" Batches iterated: {min(args.num_batches, len(dataloader))}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/monitor_gpu.sh b/monitor_gpu.sh new file mode 100755 index
0000000000..8609556855 --- /dev/null +++ b/monitor_gpu.sh @@ -0,0 +1,264 @@ +#!/bin/bash +# Monitor GPU utilization across all nodes in a SLURM job allocation. +# Queries each node via srun --overlap --jobid. +# +# Usage: +# bash monitor_gpu.sh [JOB_ID] # one-shot snapshot +# bash monitor_gpu.sh [JOB_ID] -w 10 # refresh every 10s +# bash monitor_gpu.sh [JOB_ID] -w 10 -v # with per-GPU process info + +set -uo pipefail + +RED='\033[0;31m' +YELLOW='\033[0;33m' +GREEN='\033[0;32m' +CYAN='\033[0;36m' +BOLD='\033[1m' +DIM='\033[2m' +NC='\033[0m' + +usage() { + cat <<EOF +Usage: bash monitor_gpu.sh [JOB_ID] [-w SECONDS] [-v] + JOB_ID SLURM job ID (defaults to your most recent running job) + -w SECONDS Refresh every SECONDS seconds + -v Show per-GPU process info +EOF + exit 1 +} + +# ── Parse arguments ─────────────────────────────────────────────────── +JOB_ID="" +WATCH_INTERVAL="" +VERBOSE=0 +while [[ $# -gt 0 ]]; do + case "$1" in + -w) WATCH_INTERVAL="${2:-}"; shift 2 ;; + -v) VERBOSE=1; shift ;; + -h|--help) usage ;; + *) JOB_ID="$1"; shift ;; + esac +done + +# ── Auto-detect job if none given ───────────────────────────────────── +if [[ -z "$JOB_ID" ]]; then + JOB_ID=$(squeue -u "$USER" -h -t RUNNING -o "%i" 2>/dev/null | head -1 | tr -d ' ') + if [[ -z "$JOB_ID" ]]; then + echo "Error: No running SLURM jobs found. Provide a job ID explicitly." + exit 1 + fi + echo -e "${DIM}Auto-detected running job: ${JOB_ID}${NC}" +fi + +# ── Extract job metadata ───────────────────────────────────────────── +JOB_INFO=$(scontrol show job "$JOB_ID" 2>&1) +if echo "$JOB_INFO" | grep -q "Invalid job id"; then + echo "Error: Job $JOB_ID not found." + exit 1 +fi + +JOB_STATE=$(echo "$JOB_INFO" | grep -oP 'JobState=\K\S+') +if [[ "$JOB_STATE" != "RUNNING" ]]; then + echo "Error: Job $JOB_ID is not running (state: $JOB_STATE)."
+ exit 1 +fi + +JOB_NAME=$(echo "$JOB_INFO" | grep -oP 'JobName=\K\S+') +NUM_NODES=$(echo "$JOB_INFO" | grep -oP 'NumNodes=\K\d+') +NODE_LIST=$(echo "$JOB_INFO" | grep -oP '(?<= NodeList=)\S+' | head -1) +mapfile -t NODES_ARRAY < <(scontrol show hostnames "$NODE_LIST") + +# ── nvidia-smi queries run on each node ─────────────────────────────── +QUERY_CSV="nvidia-smi --query-gpu=index,name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw --format=csv,noheader,nounits 2>/dev/null || true" +QUERY_PROCS="nvidia-smi --query-compute-apps=gpu_uuid,pid,process_name,used_memory --format=csv,noheader,nounits 2>/dev/null || true" + +# ── Helpers ─────────────────────────────────────────────────────────── +colorize_util() { + local val=$1 + if [[ $val -lt 30 ]]; then + printf "${GREEN}%3d%%${NC}" "$val" + elif [[ $val -lt 70 ]]; then + printf "${YELLOW}%3d%%${NC}" "$val" + else + printf "${RED}%3d%%${NC}" "$val" + fi +} + +run_on_node() { + local idx=$1 node=$2 outfile=$3 + + local cmd="$QUERY_CSV" + if [[ $VERBOSE -eq 1 ]]; then + cmd="${cmd}; echo '---PROCS---'; ${QUERY_PROCS}" + fi + + srun --overlap --jobid "$JOB_ID" \ + --nodes=1 --ntasks=1 -w "$node" \ + bash -c "$cmd" \ + > "$outfile" 2>&1 +} + +# ── Main query + display ───────────────────────────────────────────── +query_all_nodes() { + local tmpdir + tmpdir=$(mktemp -d) + local pids=() + + for ((i = 0; i < ${#NODES_ARRAY[@]}; i++)); do + run_on_node "$i" "${NODES_ARRAY[$i]}" "$tmpdir/$i" & + pids+=($!) + done + + for pid in "${pids[@]}"; do + wait "$pid" 2>/dev/null || true + done + + # Accumulators + local total_gpus=0 total_util=0 total_mem_used=0 total_mem_total=0 idle_gpus=0 + local timestamp + timestamp=$(date '+%Y-%m-%d %H:%M:%S') + + echo "" + echo -e "${BOLD}GPU Utilization — Job ${JOB_ID} (${JOB_NAME}) — ${timestamp}${NC}" + echo -e "${DIM}Nodes: ${NUM_NODES} | Node list: ${NODE_LIST}${NC}" + echo "" + + local hdr + hdr=$(printf "${BOLD}%-5s %-20s %3s %-22s %5s %-22s %5s %6s${NC}" \ + "Node" "Hostname" "GPU" "Model" "Util" "Memory" "Temp" "Power") + echo -e "$hdr" + printf '%.0s─' {1..94}; echo "" + + for ((i = 0; i < ${#NODES_ARRAY[@]}; i++)); do + local node="${NODES_ARRAY[$i]}" + local outfile="$tmpdir/$i" + + if [[ !
-s "$outfile" ]]; then + printf "%-5s %-20s ${RED}%s${NC}\n" "$i" "$node" "[ERROR] No response" + continue + fi + + # Separate GPU CSV from process info (if verbose), filtering noise lines + local gpu_csv proc_csv="" + if [[ $VERBOSE -eq 1 ]] && grep -q -- '---PROCS---' "$outfile" 2>/dev/null; then + gpu_csv=$(sed '/---PROCS---/,$d' "$outfile" | grep ',' || true) + proc_csv=$(sed '1,/---PROCS---/d' "$outfile") + else + gpu_csv=$(grep ',' "$outfile" || true) + fi + + # Check for error in output + if ! echo "$gpu_csv" | grep -q ","; then + printf "%-5s %-20s ${RED}%s${NC}\n" "$i" "$node" "[ERROR] $(head -1 "$outfile" | cut -c1-60)" + continue + fi + + local first_line=1 + while IFS=',' read -r gpu_idx gpu_name gpu_util mem_used mem_total temp power; do + gpu_idx=$(echo "$gpu_idx" | xargs) + gpu_name=$(echo "$gpu_name" | xargs) + gpu_util=$(echo "$gpu_util" | xargs) + mem_used=$(echo "$mem_used" | xargs) + mem_total=$(echo "$mem_total" | xargs) + temp=$(echo "$temp" | xargs) + power=$(echo "$power" | xargs) + + if ! 
[[ "$gpu_util" =~ ^[0-9]+$ ]]; then + continue + fi + + total_gpus=$((total_gpus + 1)) + total_util=$((total_util + gpu_util)) + total_mem_used=$((total_mem_used + mem_used)) + total_mem_total=$((total_mem_total + mem_total)) + if [[ $gpu_util -eq 0 ]]; then + idle_gpus=$((idle_gpus + 1)) + fi + + local util_colored + util_colored=$(colorize_util "$gpu_util") + + local mem_pct=0 + if [[ $mem_total -gt 0 ]]; then + mem_pct=$((mem_used * 100 / mem_total)) + fi + local mem_str + mem_str=$(printf "%d/%d MiB (%d%%)" "$mem_used" "$mem_total" "$mem_pct") + + local node_label="" hostname_label="" + if [[ $first_line -eq 1 ]]; then + node_label="$i" + hostname_label="$node" + first_line=0 + fi + + printf "%-5s %-20s %3s %-22s %b %-22s %4sC %5sW\n" \ + "$node_label" "$hostname_label" "$gpu_idx" "$gpu_name" "$util_colored" "$mem_str" "$temp" "$power" + done <<< "$gpu_csv" + + # Verbose: show processes under this node + if [[ $VERBOSE -eq 1 ]] && [[ -n "$proc_csv" ]]; then + while IFS=',' read -r _uuid pid pname pmem; do + pid=$(echo "$pid" | xargs) + pname=$(echo "$pname" | xargs) + pmem=$(echo "$pmem" | xargs) + if [[ -n "$pid" ]] && [[ "$pid" != "0" ]]; then + printf "${DIM} └─ PID %-8s %-30s %s MiB${NC}\n" "$pid" "$pname" "$pmem" + fi + done <<< "$proc_csv" + fi + + printf "${DIM}%.0s·${NC}" {1..94}; echo "" + done + + # ── Aggregate summary ────────────────────────────────────────────── + echo "" + printf '%.0s═' {1..94}; echo "" + if [[ $total_gpus -gt 0 ]]; then + local avg_util=$((total_util / total_gpus)) + local avg_util_colored + avg_util_colored=$(colorize_util "$avg_util") + local active_gpus=$((total_gpus - idle_gpus)) + local total_mem_gib=$((total_mem_total / 1024)) + local used_mem_gib=$((total_mem_used / 1024)) + + echo -e "${BOLD}Summary${NC}" + echo -e " Total GPUs: ${CYAN}${total_gpus}${NC}" + echo -e " Active GPUs: ${CYAN}${active_gpus}${NC} | Idle GPUs: ${CYAN}${idle_gpus}${NC}" + printf " Avg Util: %b\n" "$avg_util_colored" + echo -e " Memory: 
${CYAN}${used_mem_gib} / ${total_mem_gib} GiB${NC} (${total_mem_used} / ${total_mem_total} MiB)" + else + echo -e "${RED}No GPU data collected from any node.${NC}" + fi + echo "" + + rm -rf "$tmpdir" +} + +# ── Entry point ─────────────────────────────────────────────────────── +if [[ -n "$WATCH_INTERVAL" ]]; then + while true; do + clear + query_all_nodes + echo -e "${DIM}Refreshing every ${WATCH_INTERVAL}s — Ctrl+C to stop${NC}" + sleep "$WATCH_INTERVAL" + done +else + query_all_nodes +fi diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index df9b464ff3..9f5799e3e0 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -13,7 +13,13 @@ # limitations under the License. import os import warnings -from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast +from pathlib import Path +import sys +if sys.version_info >= (3, 11): + from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast +else: + from typing import Any, Optional, TypedDict, TypeVar, cast + from typing_extensions import NotRequired import numpy as np import ray @@ -23,7 +29,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation -from nemo_rl.algorithms.loss import ( +from nemo_rl.algorithms.loss_functions import ( DistillationLossConfig, DistillationLossDataDict, DistillationLossFn, @@ -79,9 +85,6 @@ class DistillationConfig(TypedDict): val_batch_size: int val_period: int val_at_start: bool - # Whether to run validation on the last training step. Setting this to True ensures the - # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). 
- val_at_end: bool max_val_samples: int topk_logits_k: int seed: int @@ -259,11 +262,7 @@ def setup( # Load validation dataset if provided val_dataloader: Optional[StatefulDataLoader] = None # If validation is enabled, load the validation dataloader - if ( - distillation_config["val_period"] > 0 - or distillation_config["val_at_start"] - or distillation_config["val_at_end"] - ): + if distillation_config["val_period"] > 0 or distillation_config["val_at_start"]: assert val_dataset is not None, ( "Validation dataset is required if validation is enabled" ) @@ -412,7 +411,7 @@ def setup( generation_config = cast(VllmConfig, generation_config) if "vllm_cfg" in generation_config: ## make vllm hf overrides match the training policy - generation_config["vllm_kwargs"]["hf_overrides"] = policy_config.get( + generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( "hf_config_overrides", {} ) student_generation = VllmGeneration( @@ -430,7 +429,12 @@ def setup( print("\n▶ Setting up student policy...", flush=True) # Checkpoint paths - weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path) + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + else: + weights_path = None + optimizer_path = None if "megatron_cfg" in policy_config and policy_config["megatron_cfg"]["enabled"]: ## NOTE: this is equal to the total number of scheduler steps @@ -540,7 +544,6 @@ def distillation_train( total_valid_tokens = distillation_save_state["total_valid_tokens"] val_period = master_config["distillation"]["val_period"] val_at_start = master_config["distillation"]["val_at_start"] - val_at_end = master_config["distillation"]["val_at_end"] colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] max_epochs = master_config["distillation"][ "max_num_epochs" @@ -697,12 +700,13 @@ def distillation_train( print("▶ Computing teacher 
logprobs...", flush=True) with timer.time("teacher_logprob_inference"): teacher_topk = teacher_policy.get_topk_logits( - train_data, - k=master_config["distillation"]["topk_logits_k"], - timer=timer, + train_data, k=master_config["distillation"]["topk_logits_k"] ) - train_data["teacher_topk_logits"] = teacher_topk["topk_logits"] - train_data["teacher_topk_indices"] = teacher_topk["topk_indices"] + if isinstance(teacher_topk, list): + train_data["teacher_topk_ipc_handles"] = teacher_topk + else: + train_data["teacher_topk_logits"] = teacher_topk["topk_logits"] + train_data["teacher_topk_indices"] = teacher_topk["topk_indices"] print("▶ Preparing for training...", flush=True) with timer.time("training_prep"): @@ -712,21 +716,15 @@ def distillation_train( print("▶ Training policy...", flush=True) with timer.time("policy_training"): - train_results = student_policy.train( - train_data, - loss_fn, - timer=timer, - ) + train_results = student_policy.train(train_data, loss_fn) is_last_step = (total_steps + 1 >= max_steps) or ( (current_epoch + 1 == max_epochs) and (current_step + 1 == len(dataloader)) ) - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): + # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( student_policy, student_generation, colocated_inference @@ -844,9 +842,7 @@ def distillation_train( ), optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, + ), tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), @@ -945,6 +941,7 @@ def distillation_train( current_step = 0 # Reset step counter for new epoch + def validate( policy_generation: GenerationInterface, val_dataloader: Optional[StatefulDataLoader], diff --git a/nemo_rl/algorithms/grpo.py 
b/nemo_rl/algorithms/grpo.py index 85953eb0ce..ae3be0b32f 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -17,7 +17,13 @@ import warnings from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext -from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast +from pathlib import Path +import sys +if sys.version_info >= (3, 11): + from typing import Any, NotRequired, Optional, TypedDict, TypeVar, cast +else: + from typing import Any, Optional, TypedDict, TypeVar, cast + from typing_extensions import NotRequired import numpy as np import ray @@ -26,38 +32,30 @@ from transformers import AutoProcessor from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from nemo_rl.algorithms.advantage_estimator import ( - GDPOAdvantageEstimator, - GRPOAdvantageEstimator, - ReinforcePlusPlusAdvantageEstimator, -) -from nemo_rl.algorithms.loss import ( +from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss_functions import ( ClippedPGLossConfig, ClippedPGLossDataDict, ClippedPGLossFn, ) -from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.reward_functions import ( RewardShapingConfig, apply_reward_shaping, ) from nemo_rl.algorithms.utils import ( calculate_baseline_and_std_per_prompt, - get_gdpo_reward_component_keys, log_generation_metrics_to_wandb, print_performance_metrics, set_seed, ) from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn -from nemo_rl.data.dataloader import MultipleDataloaderWrapper from nemo_rl.data.datasets import AllTaskProcessedDataset from nemo_rl.data.interfaces import DatumSpec from nemo_rl.data.llm_message_utils import ( batched_message_log_to_flat_message, get_keys_from_message_log, ) -from nemo_rl.data.utils import extract_necessary_env_names from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.ray_actor_environment_registry import 
get_actor_python_env
 from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
@@ -68,7 +66,6 @@
     run_multi_turn_rollout,
 )
 from nemo_rl.models.generation.interfaces import GenerationInterface
-from nemo_rl.models.generation.sglang import SGLangConfig, SGLangGeneration
 from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
 from nemo_rl.models.policy import PolicyConfig
 from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface
@@ -79,7 +76,6 @@
     LoggerConfig,
     print_message_log_samples,
 )
-from nemo_rl.utils.memory_tracker import MemoryTracker
 from nemo_rl.utils.nsys import maybe_gpu_profile_step
 from nemo_rl.utils.timer import TimeoutChecker, Timer
 from nemo_rl.utils.venvs import create_local_venv_on_each_node
@@ -121,17 +117,6 @@ class AsyncGRPOConfig(TypedDict):
     recompute_kv_cache_after_weight_updates: NotRequired[bool]
 
 
-class AdvEstimatorConfig(TypedDict):
-    """Configuration for advantage estimator (GRPO, GDPO, or Reinforce++)."""
-
-    name: str  # "grpo", "gdpo", or "reinforce_plus_plus"
-    # GRPO specific
-    normalize_rewards: NotRequired[bool]
-    use_leave_one_out_baseline: NotRequired[bool]
-    # Reinforce++ specific
-    minus_baseline: NotRequired[bool]
-
-
 class GRPOConfig(TypedDict):
     num_prompts_per_step: int
     num_generations_per_prompt: int
@@ -143,11 +128,7 @@ class GRPOConfig(TypedDict):
     val_period: int
     val_batch_size: int
     val_at_start: bool
-    # Whether to run validation on the last training step. Setting this to True ensures the
-    # final checkpoint has validation metrics, which is required for get_best_checkpoint_path().
-    val_at_end: bool
     max_val_samples: int
-    skip_reference_policy_logprobs_calculation: NotRequired[bool]
     seed: int
     async_grpo: NotRequired[AsyncGRPOConfig]
     overlong_filtering: NotRequired[bool]
@@ -162,13 +143,6 @@ class GRPOConfig(TypedDict):
     batch_multiplier: NotRequired[float]
     reward_shaping: RewardShapingConfig
     reward_scaling: RewardScalingConfig
-    # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation.
-    calculate_advantages_on_gpu: NotRequired[bool]
-    # Sequence-level logprob error masking for training stability. If set, mask sequences with mult_prob_error exceeding this threshold (same scale as token_mult_prob_error metric, e.g., 1.5)
-    # Note that this is slightly different than Masked Importance Sampling (MIS) because this uses the absolute value of the difference between the training and generation logprobs, whereas MIS just uses the difference between the training and generation logprobs.
-    seq_logprob_error_threshold: float | None
-    # Advantage estimator configuration (grpo or reinforce_plus_plus)
-    adv_estimator: NotRequired[AdvEstimatorConfig]
 
 
 class GRPOSaveState(TypedDict):
@@ -216,14 +190,14 @@ class MasterConfig(TypedDict):
 def setup(
     master_config: MasterConfig,
     tokenizer: TokenizerType,
-    dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset],
+    dataset: AllTaskProcessedDataset,
     val_dataset: Optional[AllTaskProcessedDataset],
     processor: Optional[AutoProcessor] = None,
 ) -> tuple[
     ColocatablePolicyInterface,
     Optional[GenerationInterface],
     tuple[RayVirtualCluster, RayVirtualCluster],
-    StatefulDataLoader | MultipleDataloaderWrapper,
+    StatefulDataLoader,
     Optional[StatefulDataLoader],
     ClippedPGLossFn,
     Logger,
@@ -276,87 +250,36 @@ def setup(
     # ==========================
     # Data
     # ==========================
-    # num_prompts_per_step and dataloader_batch_size will be different when using multiple dataloaders
-    num_prompts_per_step = grpo_config["num_prompts_per_step"]
-    if data_config["use_multiple_dataloader"]:
-        dataloader_batch_size = data_config["num_prompts_per_dataloader"]
-    else:
-        dataloader_batch_size = num_prompts_per_step
-
     # Validate batch_multiplier
     batch_multiplier = grpo_config["batch_multiplier"]
-    if grpo_config["use_dynamic_sampling"]:
-        num_prompts_per_step = int(num_prompts_per_step * batch_multiplier)
-        dataloader_batch_size = int(dataloader_batch_size * batch_multiplier)
-    else:
+    dataloader_batch_size = grpo_config["num_prompts_per_step"]
+    if not grpo_config["use_dynamic_sampling"]:
         assert batch_multiplier == 1, (
             "batch_multiplier>1 can only be used if use_dynamic_sampling=True"
         )
+    else:
+        dataloader_batch_size = int(dataloader_batch_size * batch_multiplier)
 
-    # Validate number of prompts per step
-    if data_config["use_multiple_dataloader"]:
-        assert num_prompts_per_step % dataloader_batch_size == 0, (
-            "Expected num_prompts_per_step to be a multiple of num_prompts_per_dataloader, "
-            f"but got {num_prompts_per_step} and {dataloader_batch_size}. "
-            "Please check the configuration of num_prompts_per_step and num_prompts_per_dataloader. "
-            "If use_dynamic_sampling is enabled and batch_multiplier is used, please also check the configuration of batch_multiplier."
-        )
-
-    # Load train dataset
-    def init_train_dataloader(dataset, suffix: str = ""):
-        dataloader = StatefulDataLoader(
-            dataset,
-            batch_size=dataloader_batch_size,
-            shuffle=data_config["shuffle"],
-            collate_fn=rl_collate_fn,
-            drop_last=True,
-            num_workers=data_config["num_workers"],
-        )
-        if last_checkpoint_path is not None:
-            dataloader_state_dict = torch.load(
-                os.path.join(last_checkpoint_path, f"train_dataloader{suffix}.pt")
-            )
-            dataloader.load_state_dict(dataloader_state_dict)
-        return dataloader
-
-    if data_config["use_multiple_dataloader"]:
-        # Initialize dataloaders
-        dataloaders = {}
-        for task_name, task_dataset in dataset.items():
-            dataloaders[task_name] = init_train_dataloader(
-                task_dataset, f"_{task_name}"
-            )
-            print(
-                f"  ✓ Training dataloader {task_name} loaded with {len(task_dataset)} samples",
-                flush=True,
-            )
-
-        train_sample_count = sum(
-            len(task_dataloader) for task_dataloader in dataloaders.values()
+    dataloader = StatefulDataLoader(
+        dataset,
+        batch_size=dataloader_batch_size,
+        shuffle=data_config["shuffle"],
+        collate_fn=rl_collate_fn,
+        drop_last=True,
+        num_workers=data_config["num_workers"],
+    )
+    if last_checkpoint_path is not None:
+        dataloader_state_dict = torch.load(
+            os.path.join(last_checkpoint_path, "train_dataloader.pt")
         )
+        dataloader.load_state_dict(dataloader_state_dict)
 
-        # Wrap dataloader
-        dataloader = MultipleDataloaderWrapper(
-            expected_num_prompts=num_prompts_per_step,
-            data_config=data_config,
-            dataloaders=dataloaders,
-        )
-    else:
-        dataloader = init_train_dataloader(dataset)
-        train_sample_count = len(dataloader)
-    print(
-        f"  ✓ Training dataloader loaded with {train_sample_count} samples",
-        flush=True,
-    )
+    print(f"  ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)
 
     # Load validation dataset if provided
     val_dataloader: Optional[StatefulDataLoader] = None
     # If validation is enabled, load the validation dataloader
-    if (
-        grpo_config["val_period"] > 0
-        or grpo_config["val_at_start"]
-        or grpo_config["val_at_end"]
-    ):
+    if grpo_config["val_period"] > 0 or grpo_config["val_at_start"]:
         assert val_dataset is not None, (
             "Validation dataset is required if validation is enabled"
         )
@@ -372,34 +295,17 @@ def init_train_dataloader(dataset, suffix: str = ""):
             flush=True,
         )
 
-    # ==========================
-    # Loss Function
-    # ==========================
-    loss_fn = ClippedPGLossFn(loss_config)
-
-    # Validate force_on_policy_ratio
-    if loss_config.get("force_on_policy_ratio", False):
-        assert (
-            grpo_config["num_prompts_per_step"]
-            * grpo_config["num_generations_per_prompt"]
-            == policy_config["train_global_batch_size"]
-        ), (
-            "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
-        )
-        os.environ["NRL_IGNORE_TP_ACCURACY_CHECK"] = "1"
-        print("  ✓ force_on_policy_ratio enabled")
-
     # ==========================
     # Cluster
     # ==========================
     print("\n▶ Setting up compute cluster...", flush=True)
     colocated_inference = generation_config["colocated"]["enabled"]
-
-    env_name_list = extract_necessary_env_names(data_config)
-    rm_env_enabled = "reward_model" in env_name_list
+    reward_model_enabled = (
+        "env_name" in data_config and data_config["env_name"] == "reward_model"
+    )
     total_nodes = cluster_config["num_nodes"]
-    if rm_env_enabled:
+    if reward_model_enabled:
         rm_resource = env_configs["reward_model"]["resources"]
         rm_nodes = rm_resource["num_nodes"]
         rm_gpus_per_node = rm_resource["gpus_per_node"]
@@ -476,7 +382,7 @@ def init_train_dataloader(dataset, suffix: str = ""):
             inference_nodes = 1
             # If total_nodes == 1, reward model is also on the same node; otherwise it's on a different node
             reward_gpus_to_subtract = (
-                rm_gpus_per_node if total_nodes == 1 and rm_env_enabled else 0
+                rm_gpus_per_node if total_nodes == 1 and reward_model_enabled else 0
             )
             train_gpus_per_node -= inference_gpus_per_node + reward_gpus_to_subtract
             assert train_gpus_per_node > 0, (
@@ -484,7 +390,7 @@ def init_train_dataloader(dataset, suffix: str = ""):
                 f"train_gpus_per_node:{train_gpus_per_node} = cluster_config['gpus_per_node']:{cluster_config['gpus_per_node']} - inference_gpus_per_node:{inference_gpus_per_node}"
                 + (
                     f" - rm_gpus_per_node:{rm_gpus_per_node}"
-                    if total_nodes == 1 and rm_env_enabled
+                    if total_nodes == 1 and reward_model_enabled
                     else ""
                 )
             )
@@ -543,13 +449,19 @@ def init_train_dataloader(dataset, suffix: str = ""):
     # Dictionary to store worker initialization timing stats for logging
     worker_init_timing_metrics = {}
 
-    weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path)
+    # Prepare checkpoint paths
+    if last_checkpoint_path:
+        weights_path = Path(last_checkpoint_path) / "policy" / "weights"
+        optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer"
+    else:
+        weights_path = None
+        optimizer_path = None
 
     if policy_config.get("megatron_cfg", {}).get("enabled", False):
         ## NOTE: this is equal to the total number of scheduler steps
         total_train_iters = min(
             grpo_config["max_num_steps"],
-            grpo_config["max_num_epochs"] * train_sample_count,
+            grpo_config["max_num_epochs"] * len(dataloader),
         )
         policy_config["megatron_cfg"]["train_iters"] = total_train_iters
@@ -575,77 +487,9 @@ def init_vllm():
         pg.finish_generation()
         return pg, time.perf_counter() - t0
 
-    def init_sglang():
-        """Initialize SGLang generation workers."""
-        t0 = time.perf_counter()
-        pg = SGLangGeneration(cluster=inference_cluster, config=generation_config)
-        pg.finish_generation()
-        return pg, time.perf_counter() - t0
-
-    def initialize_generation_with_policy(
-        init_generation_fn,
-        generation_name: str,
-        init_time_key: str,
-        colocated_inference: bool,
-        worker_init_timing_metrics: dict,
-    ):
-        """Generic function to initialize a generation engine (vLLM or SGLang) along with policy.
-
-        Args:
-            init_generation_fn: Function that initializes the generation engine (init_vllm or init_sglang)
-            generation_name: Name of the generation engine ("vLLM" or "SGLang")
-            init_time_key: Key name for storing initialization time in metrics ("vllm_init_time_s" or "sglang_init_time_s")
-            colocated_inference: Whether inference is colocated with training
-            worker_init_timing_metrics: Dictionary to store timing metrics
-
-        Returns:
-            Tuple of (policy_generation, policy)
-        """
-        # Determine if parallel initialization is possible (non-colocated mode)
-        use_parallel_init = not colocated_inference
-
-        if use_parallel_init:
-            # Parallel initialization: Generation engine and Policy can initialize simultaneously
-            print(
-                "  ⚡ Using parallel worker initialization (non-colocated mode)",
-                flush=True,
-            )
-
-            # Execute both initializations in parallel
-            parallel_start_time = time.perf_counter()
-            with ThreadPoolExecutor(max_workers=2) as executor:
-                generation_future = executor.submit(init_generation_fn)
-                policy_future = executor.submit(init_policy)
-                policy_generation, generation_time = generation_future.result()
-                policy, policy_time = policy_future.result()
-            parallel_wall_time = time.perf_counter() - parallel_start_time
-
-            # Store timing metrics
-            worker_init_timing_metrics[init_time_key] = generation_time
-            worker_init_timing_metrics["policy_init_time_s"] = policy_time
-            worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time
-            worker_init_timing_metrics["parallel_init_enabled"] = True
-
-        else:
-            # Sequential initialization: colocated mode (GPU memory requires generation engine first)
-            print(
-                "  ⚙️ Using sequential worker initialization (colocated mode)",
-                flush=True,
-            )
-
-            # Initialize generation engine first (clean GPU memory), then policy
-            policy_generation, generation_time = init_generation_fn()
-            worker_init_timing_metrics[init_time_key] = generation_time
-
-            policy, policy_time = init_policy()
-            worker_init_timing_metrics["policy_init_time_s"] = policy_time
-            worker_init_timing_metrics["parallel_init_enabled"] = 0.0
-
-        return policy_generation, policy
-
-    # Handle generation-specific setup
+    # Handle backend-specific setup
     if backend == "megatron":
-        # Megatron generation: policy_generation is None, only initialize policy
+        # Megatron backend: policy_generation is None, only initialize policy
         policy_generation = None
         print(
             f"  ✓ Using {backend} backend for generation with {policy_config['model_name']}",
@@ -656,7 +500,7 @@ def initialize_generation_with_policy(
         worker_init_timing_metrics["policy_init_time_s"] = policy_time
 
     elif backend == "vllm":
-        # vLLM generation: setup config, then initialize with policy
+        # vLLM backend: setup config, then decide parallel vs sequential init
        generation_config = cast(VllmConfig, generation_config)
         if generation_config["vllm_cfg"]["precision"] == "fp8":
             assert loss_config["use_importance_sampling_correction"] is True, (
@@ -680,40 +524,52 @@ def initialize_generation_with_policy(
             )
 
         ## make vllm hf overrides match the training policy
-        generation_config["vllm_kwargs"]["hf_overrides"] = policy_config.get(
+        generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get(
             "hf_config_overrides", {}
         )
 
-        policy_generation, policy = initialize_generation_with_policy(
-            init_generation_fn=init_vllm,
-            generation_name="vLLM",
-            init_time_key="vllm_init_time_s",
-            colocated_inference=colocated_inference,
-            worker_init_timing_metrics=worker_init_timing_metrics,
-        )
+        # Determine if parallel initialization is possible (non-colocated mode)
+        use_parallel_init = not colocated_inference
 
-        print(
-            f"  ✓ Using vLLM backend for generation with {policy_config['model_name']}",
-            flush=True,
-        )
+        if use_parallel_init:
+            # Parallel initialization: vLLM and Policy can initialize simultaneously
+            print(
+                "  ⚡ Using parallel worker initialization (non-colocated mode)",
+                flush=True,
+            )
 
-    elif backend == "sglang":
-        generation_config = cast(SGLangConfig, generation_config)
+            # Execute both initializations in parallel
+            parallel_start_time = time.perf_counter()
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                vllm_future = executor.submit(init_vllm)
+                policy_future = executor.submit(init_policy)
+                policy_generation, vllm_time = vllm_future.result()
+                policy, policy_time = policy_future.result()
+            parallel_wall_time = time.perf_counter() - parallel_start_time
 
-        # Set model_path if not already set
-        if "model_path" not in generation_config["sglang_cfg"]:
-            generation_config["sglang_cfg"]["model_path"] = policy_config["model_name"]
+            # Store timing metrics
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_wall_time_s"] = parallel_wall_time
+            worker_init_timing_metrics["parallel_init_enabled"] = True
 
-        policy_generation, policy = initialize_generation_with_policy(
-            init_generation_fn=init_sglang,
-            generation_name="SGLang",
-            init_time_key="sglang_init_time_s",
-            colocated_inference=colocated_inference,
-            worker_init_timing_metrics=worker_init_timing_metrics,
-        )
+        else:
+            # Sequential initialization: colocated mode (GPU memory requires vLLM first)
+            print(
+                "  ⚙️ Using sequential worker initialization (colocated mode)",
+                flush=True,
+            )
+
+            # Initialize vLLM first (clean GPU memory), then policy
+            policy_generation, vllm_time = init_vllm()
+            worker_init_timing_metrics["vllm_init_time_s"] = vllm_time
+
+            policy, policy_time = init_policy()
+            worker_init_timing_metrics["policy_init_time_s"] = policy_time
+            worker_init_timing_metrics["parallel_init_enabled"] = 0.0
 
         print(
-            f"  ✓ Using SGLang backend for generation with {policy_config['model_name']}",
+            f"  ✓ Using vLLM backend for generation with {policy_config['model_name']}",
             flush=True,
         )
@@ -748,6 +604,19 @@ def initialize_generation_with_policy(
     if policy_generation is not None:
         policy_generation.prepare_refit_info(state_dict_info)
 
+    loss_fn = ClippedPGLossFn(loss_config)
+
+    # Validate force_on_policy_ratio
+    if loss_config.get("force_on_policy_ratio", False):
+        assert (
+            grpo_config["num_prompts_per_step"]
+            * grpo_config["num_generations_per_prompt"]
+            == policy_config["train_global_batch_size"]
+        ), (
+            "force_on_policy_ratio requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt"
+        )
+        print("  ✓ force_on_policy_ratio enabled")
+
     # Calculate total setup time
     total_setup_time = time.perf_counter() - setup_start_time
     worker_init_timing_metrics["total_setup_time_s"] = total_setup_time
@@ -800,6 +669,33 @@ def initialize_generation_with_policy(
 # ===============================================================================
 
 
+def normalize_advantages_with_epsilon(
+    advantages: torch.Tensor,
+    std: torch.Tensor,
+    epsilon: float = 1e-6,
+) -> torch.Tensor:
+    """Normalize advantages by standard deviation, skipping samples with zero std.
+
+    When std is exactly zero (from leave-one-out baseline with identical rewards),
+    normalization is skipped for those samples to prevent numerical instability.
+    This makes normalize_rewards compatible with use_leave_one_out_baseline.
+
+    Args:
+        advantages: Tensor of shape (batch_size, 1) containing advantage values
+        std: Tensor of shape (batch_size,) containing standard deviation values
+        epsilon: Small value to avoid division by very small std, defaults to 1e-6
+
+    Returns:
+        Normalized advantages tensor of same shape as input advantages
+    """
+    # Only normalize where std > 0 to avoid division by near-zero
+    non_zero_std_mask = std > 0
+    advantages[non_zero_std_mask] = advantages[non_zero_std_mask] / (
+        std.unsqueeze(-1)[non_zero_std_mask] + epsilon
+    )
+    return advantages
+
+
 def dynamic_sampling(
     repeated_batch: BatchedDataDict[DatumSpec],
     std: torch.Tensor,
@@ -961,16 +857,11 @@ def scale_rewards(
     )
 
     # Clamp and scale
-    def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
-        r = torch.clamp(reward_tensor, min=source_min, max=source_max)
-        return target_min + (r - source_min) / (source_max - source_min) * (
-            target_max - target_min
-        )
-
-    scaled_rewards = _scale(rewards)
+    rewards = torch.clamp(rewards, min=source_min, max=source_max)
+    scaled_rewards = target_min + (rewards - source_min) / (
+        source_max - source_min
+    ) * (target_max - target_min)
     repeated_batch["total_reward"] = scaled_rewards
-    for key in get_gdpo_reward_component_keys(repeated_batch):
-        repeated_batch[key] = _scale(repeated_batch[key])
 
     return repeated_batch
@@ -1015,85 +906,6 @@ def _should_use_nemo_gym(master_config: MasterConfig) -> bool:
     return should_use_nemo_gym
 
 
-def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool:
-    env_config = master_config.get("env") or dict()
-    should_log_nemo_gym_responses = bool(
-        env_config.get("should_log_nemo_gym_responses")
-    )
-
-    return should_log_nemo_gym_responses
-
-
-def _create_advantage_estimator(master_config: MasterConfig):
-    """Create and return an advantage estimator based on configuration.
-
-    Args:
-        master_config: The master configuration dictionary.
-
-    Returns:
-        An advantage estimator instance (GRPO, GDPO, or ReinforcePlusPlus).
-
-    Raises:
-        ValueError: If the advantage estimator name is not recognized.
-    """
-    grpo_config = master_config["grpo"]
-    loss_config = master_config["loss_fn"]
-
-    # Provide backward-compatible defaults when adv_estimator is not in config.
-    # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline
-    # which older configs still use.
-    adv_estimator_config = grpo_config.get(
-        "adv_estimator",
-        {
-            "name": "grpo",
-            "normalize_rewards": grpo_config.get("normalize_rewards", True),
-            "use_leave_one_out_baseline": grpo_config.get(
-                "use_leave_one_out_baseline", False
-            ),
-            "minus_baseline": True,
-        },
-    )
-
-    adv_estimator_name = adv_estimator_config["name"]
-    if adv_estimator_name == "gdpo":
-        adv_estimator = GDPOAdvantageEstimator(adv_estimator_config, loss_config)
-        print("  ✓ Using GDPO advantage estimator (multi-reward)")
-    elif adv_estimator_name == "grpo":
-        adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
-        print("  ✓ Using GRPO advantage estimator")
-    elif adv_estimator_name == "reinforce_plus_plus":
-        adv_estimator = ReinforcePlusPlusAdvantageEstimator(
-            adv_estimator_config, loss_config
-        )
-        print("  ✓ Using Reinforce++ advantage estimator")
-    else:
-        raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}")
-
-    return adv_estimator
-
-
-def _extract_prompt_only_messages(message_logs: list) -> list:
-    """Extract only prompt messages (user/system) from message logs.
-
-    This is used to get prompt IDs for advantage estimation, excluding
-    any assistant responses.
-
-    Args:
-        message_logs: List of message logs, where each log is a list of messages.
-
-    Returns:
-        List of message logs containing only user and system messages.
-    """
-    prompt_only_message_logs = []
-    for message_log in message_logs:
-        prompt_only_log = []
-        for message in message_log:
-            if message["role"] == "user" or message["role"] == "system":
-                prompt_only_log.append(message)
-        prompt_only_message_logs.append(prompt_only_log)
-    return prompt_only_message_logs
-
-
 def refit_policy_generation(
     policy: ColocatablePolicyInterface,
     policy_generation: GenerationInterface,
@@ -1138,37 +950,16 @@ def refit_policy_generation(
                 policy.get_free_memory_bytes() * float(memory_ratio)
             )
 
-            if isinstance(policy_generation, SGLangGeneration):
-                sglang_url_to_gpu_uuids = (
-                    policy_generation.get_sglang_url_to_gpu_uuids()
-                )
-                # Stream weights via HTTP
-                flush_success = policy_generation.invalidate_kv_cache()
-                if not flush_success:
-                    print("SGLang KV cache invalidation failed before weight update. ")
-                futures_train = policy.stream_weights_via_http(
-                    sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
-                )
-                # Wait for all workers to complete
-                ray.get(futures_train)
-                update_success = True
-            else:
-                # Original ZMQ IPC path for vLLM
-                futures_train = policy.stream_weights_via_ipc_zmq(
-                    buffer_size_bytes=buffer_size_bytes
-                )
-                futures_inference = policy_generation.update_weights_via_ipc_zmq()
-                # wait for all futures to complete
-                ray.get(futures_train)
-                results = ray.get(futures_inference)
-                update_success = all(result for result in results if result is not None)
+            futures_train = policy.stream_weights_via_ipc_zmq(
+                buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales
+            )
+            futures_inference = policy_generation.update_weights_via_ipc_zmq()
+            # wait for all futures to complete
+            ray.get(futures_train)
+            results = ray.get(futures_inference)
+            update_success = all(result for result in results if result is not None)
         else:
             # update weights through nccl
-            # SGLang haven't implemented non-colocated inference mode.
-            if isinstance(policy_generation, SGLangGeneration):
-                raise NotImplementedError(
-                    "SGLang haven't implemented non-colocated inference mode. "
-                )
             futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales)
             futures_inference = policy_generation.update_weights_from_collective()
             # wait for all futures to complete
@@ -1191,113 +982,6 @@ def refit_policy_generation(
         policy_generation.prepare_for_generation(tags=["kv_cache"])
 
 
-def _log_mixed_rewards_and_advantages_information(
-    logger: Logger,
-    total_steps: int,
-    metrics: dict[str, Any],
-    baseline: torch.Tensor,
-    advantages: torch.Tensor,
-) -> None:
-    # The histograms that are logged are logged with a prefix "train/" to the name, since that is what the remaining metrics will be logged with.
-    logger.log_histogram(
-        baseline.numpy(), total_steps + 1, "train/baseline_reward/histogram"
-    )
-    metrics["baseline_reward/pct_0"] = 100 * (baseline == 0).float().mean().item()
-    metrics["baseline_reward/pct_1"] = 100 * (baseline == 1).float().mean().item()
-    metrics["baseline_reward/pct_mixed"] = (
-        100 - metrics["baseline_reward/pct_0"] - metrics["baseline_reward/pct_1"]
-    )
-
-    logger.log_histogram(
-        advantages.numpy(), total_steps + 1, "train/advantages/histogram"
-    )
-    metrics["advantages/sum"] = advantages.float().sum().item()
-    metrics["advantages/mean"] = advantages.float().mean().item()
-
-
-def compute_and_apply_seq_logprob_error_masking(
-    train_data: BatchedDataDict,
-    rewards: torch.Tensor,
-    seq_logprob_error_threshold: Optional[float],
-) -> tuple[float, int, float]:
-    """Compute sequence-level logprob error metrics and optionally mask high-error sequences.
-
-    This function computes the multiplicative probability error per sequence
-    (same calculation as token_mult_prob_error but aggregated per-sequence) and
-    optionally masks sequences that exceed the configured threshold.
-
-    Args:
-        train_data: Training data dict containing token_mask, sample_mask,
-            prev_logprobs, and generation_logprobs. If masking is applied,
-            sample_mask will be updated in-place.
-        rewards: Reward tensor for computing statistics on masked sequences.
-        seq_logprob_error_threshold: If set, mask sequences with mult_prob_error
-            exceeding this threshold. If None, only compute metrics.
-
-    Returns:
-        Tuple of (max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct)
-    """
-    # Compute sequence-level logprob error metrics (always)
-    token_mask = train_data["token_mask"][:, 1:]
-    sample_mask = train_data["sample_mask"]
-    prev_logprobs = train_data["prev_logprobs"][:, 1:]
-    generation_logprobs = train_data["generation_logprobs"][:, 1:]
-    lp_error = torch.abs(generation_logprobs - prev_logprobs)
-
-    # Use combined mask exactly as in loss function
-    mask = token_mask * sample_mask.unsqueeze(-1)
-
-    # Calculate sequence-level multiplicative prob error
-    # EXACT same calculation as token_mult_prob_error but per-sequence
-    seq_mult_prob_error = (torch.exp(lp_error * mask) * mask).sum(dim=-1) / mask.sum(
-        dim=-1
-    ).clamp(min=1)
-    max_seq_mult_prob_error = (
-        seq_mult_prob_error.max().item() if seq_mult_prob_error.numel() > 0 else 0.0
-    )
-
-    # Apply sequence-level masking if configured
-    num_masked_seqs = 0
-    masked_correct_pct = 0.0
-
-    if seq_logprob_error_threshold is not None:
-        print(
-            f"▶ Applying sequence-level logprob error masking (threshold={seq_logprob_error_threshold})...",
-            flush=True,
-        )
-
-        original_sample_mask = sample_mask.clone()
-
-        # Create mask for sequences below threshold
-        seq_error_mask = (
-            seq_mult_prob_error <= seq_logprob_error_threshold
-        ).float() * original_sample_mask
-
-        diff_mask = original_sample_mask - seq_error_mask
-        num_masked_seqs = int(diff_mask.sum().item())
-
-        if num_masked_seqs > 0:
-            diff_mask_bool = diff_mask.bool()
-            masked_correct_count = (rewards.view(-1)[diff_mask_bool] == 1).sum().item()
-            masked_correct_pct = masked_correct_count / num_masked_seqs
-
-        # Update sample_mask in train_data
-        train_data["sample_mask"] = seq_error_mask
-
-        print(
-            f"  Masked {num_masked_seqs} sequences with mult_prob_error > {seq_logprob_error_threshold}",
-            flush=True,
-        )
-        if num_masked_seqs > 0:
-            print(
-                f"  • {masked_correct_count}/{num_masked_seqs} masked sequences were correct (reward=1)"
-                f" → {masked_correct_pct:.2%}",
-                flush=True,
-            )
-
-    return max_seq_mult_prob_error, num_masked_seqs, masked_correct_pct
-
-
 # ===============================================================================
 # Training & Validation
 # ===============================================================================
@@ -1306,7 +990,7 @@ def compute_and_apply_seq_logprob_error_masking(
 def grpo_train(
     policy: ColocatablePolicyInterface,
     policy_generation: Optional[GenerationInterface],
-    wrapped_dataloader: StatefulDataLoader | MultipleDataloaderWrapper,
+    dataloader: StatefulDataLoader,
     val_dataloader: Optional[StatefulDataLoader],
     tokenizer: TokenizerType,
     loss_fn: LossFunction,
@@ -1316,6 +1000,7 @@ def grpo_train(
     checkpointer: CheckpointManager,
     grpo_save_state: GRPOSaveState,
     master_config: MasterConfig,
+    processor: Optional[AutoProcessor] = None,
 ) -> None:
     """Run GRPO training algorithm."""
     timer = Timer()
@@ -1324,7 +1009,6 @@ def grpo_train(
         fit_last_save_time=True,
     )
     timeout.start_iterations()
-    memory_tracker = MemoryTracker()
 
     kv_scales_cache = None  # Cache reused for computed kv scales
@@ -1336,17 +1020,11 @@ def grpo_train(
     POLICY_GENERATION_STALE = True  # tracks if generation needs a refit before running
     assert policy_generation is not None  # for mypy type check
 
-    if master_config["grpo"].get("skip_reference_policy_logprobs_calculation"):
-        assert master_config["loss_fn"]["reference_policy_kl_penalty"] == 0
-        print(
-            "Reference policy logprob calculation will be skipped since `grpo.skip_reference_policy_logprobs_calculation` is set to True and `loss_fn.reference_policy_kl_penalty` is 0."
-        )
-
     # Check if we need to sync KV cache scales
     # When fallback to policy as the policy_generation, we use getattr to check.
     sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False)
 
-    # common config/state times
+    # common config/state itmes
     current_step = grpo_save_state["current_step"]  # current step within an epoch
     total_steps = grpo_save_state["total_steps"]  # total steps across all epochs
     max_num_steps = master_config["grpo"][
@@ -1363,19 +1041,13 @@ def grpo_train(
         "total_valid_tokens", 0
     )  # total valid tokens processed across all epochs; default to 0 for backward compatibility with older checkpoints
     val_at_start = master_config["grpo"]["val_at_start"]
-    val_at_end = master_config["grpo"]["val_at_end"]
     val_period = master_config["grpo"]["val_period"]
     colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
 
-    # Initialize advantage estimator
-    adv_estimator = _create_advantage_estimator(master_config)
-
     # Run validation at the start if configured
     # TODO: Add validation with kv scales if needed
     if val_at_start and current_step == 0:
         print("\n🔍 Running initial validation...", flush=True)
-        memory_tracker.snapshot_start_of_stage("Initial validation", dir())
-
         if NEED_REFIT and POLICY_GENERATION_STALE:
             refit_policy_generation(policy, policy_generation, colocated_inference)
             POLICY_GENERATION_STALE = False
@@ -1388,21 +1060,12 @@ def grpo_train(
             val_task_to_env,
             step=0,
             master_config=master_config,
-            logger=logger,
         )
         policy_generation.finish_generation()
         logger.log_metrics(val_metrics, current_step, prefix="validation")
         logger.log_metrics(validation_timings, current_step, prefix="timing/validation")
 
-    if master_config["data"]["use_multiple_dataloader"]:
-        warnings.warn(
-            "When using multiple dataloaders, MultipleDataloaderWrapper operates as an infinite iterator. "
-            "As a result, grpo.max_num_epochs will be ignored, and only grpo.max_num_steps will be used. "
-            "See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details."
-        )
-
     while current_epoch < max_num_epochs and total_steps < max_num_steps:
-        memory_tracker.snapshot_start_of_stage("Preparing batch", dir())
         print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}")
         # batch cache is used for DAPO. We store prompts with non-zero standard deviation in this cache.
         batch_cache: BatchedDataDict[DatumSpec] = None
@@ -1410,22 +1073,11 @@ def grpo_train(
         dynamic_sampling_num_gen_batches = 0
 
         # Run grpo/dapo training loop (single-turn)
-        for batch in wrapped_dataloader:
-            # A central place to store logging data that won't be deleted until the loop ends
-            metrics_logging_data = dict()
-            metrics = dict()
-
-            if master_config["data"]["use_multiple_dataloader"]:
-                print(
-                    f"\n{'=' * 25} Step {current_step + 1}/{max_num_steps} {'=' * 25}",
-                    flush=True,
-                )
-            else:
-                print(
-                    f"\n{'=' * 25} Step {current_step + 1}/{min(len(wrapped_dataloader), max_num_steps)} {'=' * 25}",
-                    flush=True,
-                )
-
+        for batch in dataloader:
+            print(
+                f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_num_steps)} {'=' * 25}",
+                flush=True,
+            )
             maybe_gpu_profile_step(policy, total_steps + 1)
             if policy != policy_generation:
                 maybe_gpu_profile_step(policy_generation, total_steps + 1)
@@ -1449,7 +1101,6 @@ def grpo_train(
             input_ids = batched_flat["token_ids"]
 
             # Generate responses - this updates the LLMMessageLogType in repeated_batch
-            memory_tracker.snapshot_start_of_stage("Generation", dir())
             print(
                 f"▶ Generating responses for batch of size {repeated_batch.size}...",
                 flush=True,
            )
@@ -1501,14 +1152,12 @@ def grpo_train(
                 policy_generation.prepare_for_generation()
 
             dynamic_sampling_num_gen_batches += 1
-            if dynamic_sampling_num_gen_batches == 1 and hasattr(
-                policy_generation, "snapshot_step_metrics"
-            ):
-                policy_generation.snapshot_step_metrics()
             with timer.time("generation"):
-                # Clear logger metrics for each generation step
-                if policy_generation is not None:
-                    policy_generation.clear_logger_metrics()
+                # Clear vLLM logger metrics for each generation step
+                if policy_generation is not None and hasattr(
+                    policy_generation, "clear_vllm_logger_metrics"
+                ):
+                    policy_generation.clear_vllm_logger_metrics()
                 # Use NeMo-Gym rollouts if enabled. We cascade NeMo-Gym first since NeMo-Gym requires async rollouts.
                 if _should_use_nemo_gym(master_config):
                     generation_config = master_config["policy"]["generation"]
@@ -1525,14 +1174,6 @@ def grpo_train(
                     input_ids = nemo_gym_rollout_result.input_ids
                     repeated_batch = nemo_gym_rollout_result.final_batch
                     rollout_metrics = nemo_gym_rollout_result.rollout_metrics
-                    del nemo_gym_rollout_result
-
-                    # NeMo Gym responses can be very large and expensive to log. Here we have logic to opt-in to logging.
-                    if not _should_log_nemo_gym_responses(master_config):
-                        for key in list(rollout_metrics):
-                            if "full_result" in key:
-                                rollout_metrics.pop(key)
-
                 # Use async rollouts if vLLM async engine is enabled
                 elif _should_use_async_rollouts(master_config):
                     (
@@ -1566,17 +1207,16 @@ def grpo_train(
                         greedy=False,
                     )
                     policy_generation.finish_generation()
-                # Collect generation logger metrics for performance reporting after each generation step
-                # inflight batch sizes and num pending samples are collected from each worker
-                if policy_generation is not None:
-                    generation_logger_metrics = (
-                        policy_generation.get_logger_metrics()
+                # Collect vLLM logger metrics for performance reporting after each generation step
+                # inflight batch sizes and num pending samples are collected from each vLLM worker
+                if policy_generation is not None and hasattr(
+                    policy_generation, "get_vllm_logger_metrics"
+                ):
+                    vllm_logger_metrics = (
+                        policy_generation.get_vllm_logger_metrics()
                     )
-
-                metrics_logging_data["mean_gen_tokens_per_sample"] = (
-                    rollout_metrics["mean_gen_tokens_per_sample"]
-                )
-                logger.log_metrics(rollout_metrics, total_steps + 1, prefix="train")
+                else:
+                    vllm_logger_metrics = {}
 
             repeated_batch = scale_rewards(
                 repeated_batch, master_config["grpo"]["reward_scaling"]
@@ -1588,37 +1228,20 @@ def grpo_train(
             )
 
             # Calculate rewards & advantages
-            memory_tracker.snapshot_start_of_stage("Processing rewards", dir())
            print("▶ Processing rewards...,", flush=True)
             with timer.time("reward_calculation"):
                 # Extract rewards from final_batch
                 rewards = repeated_batch["total_reward"]
 
                 print("▶ Computing advantages...", flush=True)
-                if master_config["grpo"].get("calculate_advantages_on_gpu"):
-                    print("Computing advantages on GPU!")
-                    # Just fix the device id for now
-                    device_id = 0
-                    baseline, std = calculate_baseline_and_std_per_prompt(
-                        input_ids.cuda(device_id),
-                        rewards.cuda(device_id),
-                        torch.ones_like(rewards).cuda(device_id),
-                        leave_one_out_baseline=master_config["grpo"][
-                            "use_leave_one_out_baseline"
-                        ],
-                    )
-                    baseline = baseline.cpu()
-                    std = std.cpu()
-                else:
-                    baseline, std = calculate_baseline_and_std_per_prompt(
-                        input_ids,
-                        rewards,
-                        torch.ones_like(rewards),
-                        leave_one_out_baseline=master_config["grpo"][
-                            "use_leave_one_out_baseline"
-                        ],
-                    )
-
+                baseline, std = calculate_baseline_and_std_per_prompt(
+                    input_ids,
+                    rewards,
+                    torch.ones_like(rewards),
+                    leave_one_out_baseline=master_config["grpo"][
+                        "use_leave_one_out_baseline"
+                    ],
+                )
                 # Apply dynamic sampling to filter prompts with non-zero std (DAPO algorithm)
                 repeated_batch, is_batch_complete, batch_cache, ds_metrics = (
                     dynamic_sampling(
@@ -1647,28 +1270,13 @@ def grpo_train(
                 # If the current batch is not enough to fill the buffer during dynamic sampling, we update the cache and process the next batch.
                 if not is_batch_complete:
                     continue
+                advantages = (rewards - baseline).unsqueeze(-1)
 
-                gen_step_metrics = {}
-                if hasattr(policy_generation, "get_step_metrics"):
-                    gen_step_metrics = policy_generation.get_step_metrics()
-
-                # Save baseline for logging (before deletion)
-                baseline_for_log = baseline.clone()
-
-                # Extract prompt-only messages for advantage estimation
-                prompt_only_message_logs = _extract_prompt_only_messages(
-                    repeated_batch["message_log"]
-                )
-                prompt_batched_flat, _ = batched_message_log_to_flat_message(
-                    prompt_only_message_logs,
-                    pad_value_dict={"token_ids": tokenizer.pad_token_id},
-                )
-                prompt_ids_for_adv = prompt_batched_flat["token_ids"]
-                del prompt_only_message_logs
-                del prompt_batched_flat
-                del input_ids
-                del baseline
-                del std
+                if master_config["grpo"]["normalize_rewards"]:
+                    advantages = normalize_advantages_with_epsilon(
+                        advantages=advantages,
+                        std=std,
+                    )
 
             with timer.time("data_processing"):
                 use_overlong_filtering = master_config["grpo"]["overlong_filtering"]
@@ -1681,7 +1289,7 @@ def grpo_train(
                     loss_multiplier[truncated] = 0
                     repeated_batch["loss_multiplier"] = loss_multiplier
 
-                # Add loss mask to each message in LLMMessageLogType
+                # Add loss mask and advantages to each message in LLMMessageLogType
                 for i, message_log in enumerate(repeated_batch["message_log"]):
                     for j, message in enumerate(message_log):
                         if message["role"] == "assistant":
@@ -1696,6 +1304,9 @@ def grpo_train(
                             message["generation_logprobs"] = torch.zeros_like(
                                 message["token_ids"], dtype=torch.float32
                             )
+                            message["advantages"] = advantages[i].expand(
+                                message["token_ids"].shape
+                            )
 
                 # Convert updated LLMMessageLogType to FlatMessagesType for training
                 flat_messages, input_lengths = batched_message_log_to_flat_message(
@@ -1707,101 +1318,35 @@ def grpo_train(
                 )
 
                 # Create training data from flattened messages
-                # Note: advantages will be computed and added after logprobs are available
                 train_data = BatchedDataDict[ClippedPGLossDataDict](
                     {
                         "input_ids": flat_messages["token_ids"],
                         "input_lengths": input_lengths,
+                        "advantages": flat_messages["advantages"],
                         "generation_logprobs": flat_messages["generation_logprobs"],
                         "token_mask": flat_messages["token_loss_mask"],
                         "sample_mask": repeated_batch["loss_multiplier"],
                     }
                 )
                 # this will be mini-batched inside the policy, so maintain the packed multimodal structure
-                # This is also used to populate part of the downstream logprob calculation data
-                extra_multimodal_data = flat_messages.get_multimodal_dict(
-                    as_tensors=False
+                train_data.update(
+                    flat_messages.get_multimodal_dict(as_tensors=False)
                 )
-                train_data.update(extra_multimodal_data)
                 train_data.to("cpu")
-                metrics_logging_data["content"] = flat_messages["content"]
-
-            memory_tracker.snapshot_start_of_stage("Computing logprobs", dir())
             print("▶ Preparing for logprob inference...", flush=True)
             with timer.time("logprob_inference_prep"):
                 policy.prepare_for_lp_inference()

             print("▶ Computing logprobs...", flush=True)
             with timer.time("policy_and_reference_logprobs"):
-                # Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers.
- logprob_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": train_data["input_ids"], - "input_lengths": train_data["input_lengths"], - "token_mask": flat_messages["token_loss_mask"], - "sample_mask": repeated_batch["loss_multiplier"], - **extra_multimodal_data, - } - ) - train_data["prev_logprobs"] = policy.get_logprobs( - logprob_data, timer=timer - )["logprobs"] - - if not master_config["grpo"].get( - "skip_reference_policy_logprobs_calculation" - ): - train_data["reference_policy_logprobs"] = ( - policy.get_reference_policy_logprobs( - logprob_data, - timer=timer, - )["reference_logprobs"] - ) - - del logprob_data - del extra_multimodal_data - - ( - max_seq_mult_prob_error, - num_masked_seqs, - masked_correct_pct, - ) = compute_and_apply_seq_logprob_error_masking( - train_data=train_data, - rewards=rewards, - seq_logprob_error_threshold=master_config["grpo"][ - "seq_logprob_error_threshold" - ], - ) - - # Compute advantages with adv_estimator using correct mask and logprobs - with timer.time("advantage_calculation"): - print("▶ Computing advantages...", flush=True) - # Get token-level mask: token_mask * sample_mask - token_mask = train_data["token_mask"] - sample_mask = train_data["sample_mask"] - mask = token_mask * sample_mask.unsqueeze(-1) - - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - repeated_batch=repeated_batch, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - del prompt_ids_for_adv - - # Log rewards and advantages information - _log_mixed_rewards_and_advantages_information( - logger=logger, - total_steps=total_steps, - metrics=metrics, - baseline=baseline_for_log, - advantages=train_data["advantages"], - ) - del baseline_for_log + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + reference_logprobs = policy.get_reference_policy_logprobs( + train_data + 
)["reference_logprobs"] + train_data["prev_logprobs"] = fprop_logprobs + train_data["reference_policy_logprobs"] = reference_logprobs - memory_tracker.snapshot_start_of_stage("Policy train", dir()) print("▶ Preparing for training...", flush=True) with timer.time("training_prep"): policy.prepare_for_training() # set model train and reload optim to GPU @@ -1809,11 +1354,7 @@ def grpo_train( print("▶ Training policy...", flush=True) with timer.time("policy_training"): - train_results = policy.train( - train_data, - loss_fn, - timer=timer, - ) + train_results = policy.train(train_data, loss_fn) # Recompute KV scales after policy training if needed if sync_kv_scales: @@ -1828,18 +1369,13 @@ def grpo_train( # Set generation as stale to force refit with new scales POLICY_GENERATION_STALE = True - is_last_step = total_steps + 1 >= max_num_steps - if not master_config["data"]["use_multiple_dataloader"]: - is_last_step = is_last_step or ( - (current_epoch + 1 == max_num_epochs) - and (current_step + 1 == len(wrapped_dataloader)) - ) + is_last_step = (total_steps + 1 >= max_num_steps) or ( + (current_epoch + 1 == max_num_epochs) + and (current_step + 1 == len(dataloader)) + ) - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): - memory_tracker.snapshot_start_of_stage("Validation", dir()) + # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( policy, @@ -1859,7 +1395,6 @@ def grpo_train( val_task_to_env, step=total_steps + 1, master_config=master_config, - logger=logger, ) policy_generation.finish_generation() logger.log_metrics( @@ -1870,7 +1405,7 @@ def grpo_train( ) # Get flat advantages and token mask for masked metrics computation - flat_advantages = train_data["advantages"] + flat_advantages = flat_messages["advantages"] 
flat_token_mask = flat_messages["token_loss_mask"] # Filter advantages using token mask (only valid response tokens) @@ -1878,9 +1413,7 @@ def grpo_train( flat_advantages, flat_token_mask.bool() ) - memory_tracker.snapshot_start_of_stage("Metrics", dir()) metrics = { - **metrics, "loss": train_results["loss"].numpy(), "grad_norm": train_results["grad_norm"].numpy(), "reward": rewards.numpy(), @@ -1907,7 +1440,6 @@ def grpo_train( metrics["reward"] = repeated_batch["total_reward"].numpy() metrics.update(train_results["all_mb_metrics"]) - metrics.update(gen_step_metrics) for k, v in metrics.items(): if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: valid_values = [x for x in v if not np.isinf(x)] @@ -1929,20 +1461,13 @@ def grpo_train( "mean_prompt_length", }: metrics[k] = np.mean(v).item() - elif isinstance(v, (np.ndarray, list)): - metrics[k] = np.sum(v).item() else: - print(f"Skipping aggregation for {k} ({type(v)})") + metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) - metrics["generation_logger_metrics"] = generation_logger_metrics + metrics["vllm_logger_metrics"] = vllm_logger_metrics total_valid_tokens += metrics["global_valid_toks"] - # Always log sequence-level error metrics (useful for deciding threshold) - metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error - metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs - metrics["masked_correct_pct"] = masked_correct_pct - ## Checkpointing consumed_samples += master_config["grpo"]["num_prompts_per_step"] timeout.mark_iteration() @@ -1956,7 +1481,6 @@ def grpo_train( # Check if timeout-based checkpointing is enabled in config. 
should_save_by_timeout = timeout.check_save() - memory_tracker.snapshot_start_of_stage("Checkpointing", dir()) if master_config["checkpointing"]["enabled"] and ( should_save_by_step or should_save_by_timeout ): @@ -2016,60 +1540,32 @@ def grpo_train( ), optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, + ), tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), checkpointing_cfg=master_config["checkpointing"], ) - if master_config["data"]["use_multiple_dataloader"]: - for ( - task_name, - task_dataloader, - ) in wrapped_dataloader.dataloaders.items(): - torch.save( - task_dataloader.state_dict(), - os.path.join( - checkpoint_path, - f"train_dataloader_{task_name}.pt", - ), - ) - else: - torch.save( - wrapped_dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), - ) + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) checkpointer.finalize_checkpoint(checkpoint_path) # Logging # Log training data - memory_tracker.snapshot_start_of_stage("Logging", dir()) - if not _should_log_nemo_gym_responses(master_config): - log_data = {} - if "agent_ref" in repeated_batch: - log_data["agent_ref"] = repeated_batch["agent_ref"] - log_data["content"] = flat_messages["content"] - log_data["rewards"] = rewards.tolist() - if master_config["grpo"]["use_dynamic_sampling"]: - log_data["filtered_rewards"] = rewards.tolist() - log_data["rewards"] = repeated_batch["total_reward"].tolist() - log_data["input_lengths"] = input_lengths.tolist() - log_data["token_ids"] = train_data["input_ids"].tolist() - log_data["token_loss_mask"] = train_data["token_mask"].tolist() - log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() - log_data["advantages"] = train_data["advantages"].tolist() - log_data["generation_logprobs"] = train_data[ - "generation_logprobs" - ].tolist() - log_data["prev_logprobs"] = 
train_data["prev_logprobs"].tolist() - - logger.log_batched_dict_as_jsonl( - log_data, f"train_data_step{total_steps + 1}.jsonl" - ) - del log_data - del flat_messages + log_data = {"content": flat_messages["content"]} + log_data["rewards"] = rewards.tolist() + if master_config["grpo"]["use_dynamic_sampling"]: + log_data["filtered_rewards"] = rewards.tolist() + log_data["rewards"] = repeated_batch["total_reward"].tolist() + + log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps + 1}.jsonl" + ) timing_metrics: dict[str, float] = timer.get_timing_metrics( reduction_op="sum" @@ -2088,12 +1584,11 @@ def grpo_train( total_steps + 1, name="train/token_mult_prob_error_plot_sample", ) - del train_data if master_config["policy"]["generation"].get("vllm_cfg", {}).get( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - generation_logger_metrics, + vllm_logger_metrics, total_steps + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" @@ -2118,8 +1613,6 @@ def grpo_train( print("\n📊 Training Results:") print(f" • Loss: {metrics['loss']:.4f}") - if "draft_loss" in metrics: - print(f" • Draft Loss: {metrics['draft_loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") if master_config["grpo"]["use_dynamic_sampling"]: print(f" • Avg Filtered Reward: {np.mean(rewards.numpy()):.4f}") @@ -2129,7 +1622,7 @@ def grpo_train( else: print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print( - f" • Mean Generation Length: {metrics_logging_data['mean_gen_tokens_per_sample']:.4f}", + f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}", flush=True, ) @@ -2167,39 +1660,19 @@ def grpo_train( 
logger.log_metrics( performance_metrics, total_steps + 1, prefix="performance" ) - # step_finished=True here since this is the final log of our current step. - logger.log_metrics( - timing_metrics, - total_steps + 1, - prefix="timing/train", - step_finished=True, - ) + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") # Reset the batch and set dynamic_sampling_num_gen_batches to 0 batch_cache = None dynamic_sampling_num_gen_batches = 0 - # Clear mem - memory_tracker.snapshot_start_of_stage("After CPU memory clear", dir()) - - # processing rewards - del repeated_batch - del rewards - # train_data already deleted after logging above - # logging - del metrics - if "val_metrics" in dir(): - del val_metrics - timer.reset() current_step += 1 total_steps += 1 if should_save_by_timeout: - memory_tracker.snapshot_start_of_stage("", dir()) print("Timeout has been reached, stopping training early", flush=True) return if total_steps >= max_num_steps: - memory_tracker.snapshot_start_of_stage("", dir()) print( "Max number of steps has been reached, stopping training early", flush=True, @@ -2217,7 +1690,6 @@ def validate( val_task_to_env: Optional[dict[str, EnvironmentInterface]], step: int, master_config: MasterConfig, - logger: Optional[Logger] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: """Run validation on the validation dataset.""" if val_dataloader is None: @@ -2344,14 +1816,6 @@ def validate( validation_time = timing_metrics.get("total_validation_time", 0) print(f" • Total validation time: {validation_time:.2f}s", flush=True) - # Log validation data to JSONL file - if logger is not None: - val_log_data = { - "content": all_message_logs, - "rewards": total_rewards, - } - logger.log_batched_dict_as_jsonl(val_log_data, f"val_data_step{step}.jsonl") - # Make sure to reset the timer after validation timer.reset() @@ -2439,12 +1903,8 @@ def async_grpo_train( ) # Default to 0 for backward compatibility with older checkpoints val_period = 
master_config["grpo"]["val_period"]
     val_at_start = master_config["grpo"]["val_at_start"]
-    val_at_end = master_config["grpo"]["val_at_end"]
     colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

-    # Initialize advantage estimator
-    adv_estimator = _create_advantage_estimator(master_config)
-
     assert not colocated_inference, (
         "Colocated inference is not supported for async GRPO. Please use non-colocated inference."
     )
@@ -2580,7 +2040,6 @@ def async_grpo_train(
                 val_task_to_env,
                 step=0,
                 master_config=master_config,
-                logger=logger,
             )
             policy_generation.finish_generation()
             logger.log_metrics(val_metrics, step, prefix="validation")
@@ -2597,9 +2056,12 @@ def async_grpo_train(
     trajectory_collector.resume.remote()

     print("✅ All setup complete, starting buffer wait...")
-    # Clear logger metrics at start of training
-    if policy_generation is not None:
-        policy_generation.clear_logger_metrics()
+
+    # Clear vLLM logger metrics at the start of training
+    if policy_generation is not None and hasattr(
+        policy_generation, "clear_vllm_logger_metrics"
+    ):
+        policy_generation.clear_vllm_logger_metrics()

     # Wait for initial buffer fill
     print(
@@ -2722,27 +2184,59 @@ def async_grpo_train(
         print("▶ Processing rewards...")
         with timer.time("reward_calculation"):
-            # Extract prompt-only messages for advantage estimation
-            prompt_only_message_logs = _extract_prompt_only_messages(
-                repeated_batch["message_log"]
-            )
-            prompt_batched_flat, _ = batched_message_log_to_flat_message(
-                prompt_only_message_logs,
-                pad_value_dict={"token_ids": tokenizer.pad_token_id},
+            prompt_only_message_logs = []
+            for message_log in repeated_batch["message_log"]:
+                prompt_only_log = []
+                for message in message_log:
+                    if message["role"] == "user" or message["role"] == "system":
+                        prompt_only_log.append(message)
+                prompt_only_message_logs.append(prompt_only_log)
+
+            prompt_batched_flat, prompt_input_lengths = (
+                batched_message_log_to_flat_message(
+                    prompt_only_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) ) - prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs - del prompt_batched_flat + prompt_only_ids = prompt_batched_flat["token_ids"] rewards = repeated_batch["total_reward"] + print("▶ Computing advantages...") + + baseline, std = calculate_baseline_and_std_per_prompt( + prompt_only_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], + ) + advantages = (rewards - baseline).unsqueeze(-1) + print( f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}" ) + print( + f" 📊 Baseline stats: min={baseline.min():.4f}, max={baseline.max():.4f}, mean={baseline.mean():.4f}" + ) + print( + f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" + ) + + if master_config["grpo"]["normalize_rewards"]: + advantages = normalize_advantages_with_epsilon( + advantages=advantages, + std=std, + ) + + print( + f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" + ) # Prepare training data (same as sync version) with timer.time("data_processing"): - # Add loss mask to each message + # Add loss mask and advantages to each message for i, message_log in enumerate(repeated_batch["message_log"]): for j, message in enumerate(message_log): if message["role"] == "assistant": @@ -2757,6 +2251,9 @@ def async_grpo_train( message["generation_logprobs"] = torch.zeros_like( message["token_ids"], dtype=torch.float32 ) + message["advantages"] = advantages[i].expand( + message["token_ids"].shape + ) # Convert to flat format for training flat_messages, input_lengths = batched_message_log_to_flat_message( @@ -2768,11 +2265,11 @@ def async_grpo_train( ) # Create training data - # Note: advantages will be 
computed and added after logprobs are available train_data = BatchedDataDict[ClippedPGLossDataDict]( { "input_ids": flat_messages["token_ids"], "input_lengths": input_lengths, + "advantages": flat_messages["advantages"], "generation_logprobs": flat_messages["generation_logprobs"], "token_mask": flat_messages["token_loss_mask"], "sample_mask": repeated_batch["loss_multiplier"], @@ -2787,57 +2284,13 @@ def async_grpo_train( print("▶ Computing logprobs...") with timer.time("policy_and_reference_logprobs"): - fprop_logprobs = policy.get_logprobs( - train_data, - timer=timer, - )["logprobs"] + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] reference_logprobs = policy.get_reference_policy_logprobs( - train_data, - timer=timer, + train_data )["reference_logprobs"] train_data["prev_logprobs"] = fprop_logprobs train_data["reference_policy_logprobs"] = reference_logprobs - ( - max_seq_mult_prob_error, - num_masked_seqs, - masked_correct_pct, - ) = compute_and_apply_seq_logprob_error_masking( - train_data=train_data, - rewards=rewards, - seq_logprob_error_threshold=master_config["grpo"][ - "seq_logprob_error_threshold" - ], - ) - - # Compute advantages with adv_estimator using correct mask and logprobs - with timer.time("advantage_calculation"): - print("▶ Computing advantages...", flush=True) - # Get token-level mask: token_mask * sample_mask - token_mask = train_data["token_mask"] - sample_mask = train_data["sample_mask"] - mask = token_mask * sample_mask.unsqueeze(-1) - - train_data["advantages"] = adv_estimator.compute_advantage( - prompt_ids=prompt_ids_for_adv, - rewards=rewards, - mask=mask, - repeated_batch=repeated_batch, - logprobs_policy=train_data["prev_logprobs"], - logprobs_reference=train_data.get("reference_policy_logprobs"), - ) - del prompt_ids_for_adv - - # Log advantages stats - # Note: For GRPOAdvantageEstimator with normalize_rewards=True, these are - # already normalized advantages (equivalent to "Normalized advantages stats" - # in older 
versions). For ReinforcePlusPlusAdvantageEstimator, advantages - # are globally normalized across valid tokens. - advantages = train_data["advantages"] - print( - f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" - ) - print("▶ Preparing for training...") with timer.time("training_prep"): policy.prepare_for_training() @@ -2845,26 +2298,26 @@ def async_grpo_train( print("▶ Training policy...") with timer.time("policy_training"): - train_results = policy.train( - train_data, - loss_fn, - timer=timer, - ) + train_results = policy.train(train_data, loss_fn) print("🔄 Synchronizing policy weights to trajectory collector…") - generation_logger_metrics = None + vllm_logger_metrics = None if NEED_REFIT: # Measure pending-generation wait as exposed_generation time print("🔄 Coordinating with trajectory collector before refit...") with timer.time("exposed_generation"): ray.get(trajectory_collector.prepare_for_refit.remote()) - # Collect generation logger metrics for performance reporting - # inflight batch sizes and num pending samples are collected from each worker - if policy_generation is not None: - generation_logger_metrics = ( - policy_generation.get_logger_metrics() + # Collect vLLM logger metrics for performance reporting + # inflight batch sizes and num pending samples are collected from each vLLM worker + if policy_generation is not None and hasattr( + policy_generation, "get_vllm_logger_metrics" + ): + vllm_logger_metrics = ( + policy_generation.get_vllm_logger_metrics() ) + else: + vllm_logger_metrics = {} # Only the actual refit/weight transfer should be counted as weight_sync print("🔄 Performing policy generation refit...") @@ -2879,18 +2332,17 @@ def async_grpo_train( trajectory_collector.set_weight_version.remote(weight_version) trajectory_collector.resume_after_refit.remote() - # Clear logger metrics after each refit (weight sync), starting a new logging cycle - if 
policy_generation is not None: - policy_generation.clear_logger_metrics() + # Clear vLLM logger metrics after each refit (weight sync), starting a new logging cycle + if policy_generation is not None and hasattr( + policy_generation, "clear_vllm_logger_metrics" + ): + policy_generation.clear_vllm_logger_metrics() # Validation val_metrics, validation_timings = None, None is_last_step = step + 1 == master_config["grpo"]["max_num_steps"] - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (step + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): + if val_period > 0 and (step + 1) % val_period == 0: # Pause trajectory collection during validation to reduce memory pressure trajectory_collector.pause.remote() @@ -2908,7 +2360,6 @@ def async_grpo_train( val_task_to_env, step=step + 1, master_config=master_config, - logger=logger, ) policy_generation.finish_generation() logger.log_metrics( @@ -2925,11 +2376,8 @@ def async_grpo_train( # Resume trajectory collection after validation trajectory_collector.resume.remote() # Get flat advantages and token mask for masked metrics computation - flat_advantages = train_data["advantages"] + flat_advantages = flat_messages["advantages"] flat_token_mask = flat_messages["token_loss_mask"] - # Save content for logging before deleting flat_messages - flat_messages_content = flat_messages.get("content", []) - del flat_messages # Filter advantages using token mask (only valid response tokens) response_advantages = torch.masked_select( @@ -2981,15 +2429,10 @@ def async_grpo_train( else: metrics[k] = np.sum(v).item() metrics.update(rollout_metrics) - if generation_logger_metrics is not None: - metrics["generation_logger_metrics"] = generation_logger_metrics + if vllm_logger_metrics is not None: + metrics["vllm_logger_metrics"] = vllm_logger_metrics total_valid_tokens += metrics["global_valid_toks"] - # Always log sequence-level error metrics (useful for deciding threshold) - 
metrics["max_seq_mult_prob_error"] = max_seq_mult_prob_error - metrics["num_masked_seqs_by_logprob_error"] = num_masked_seqs - metrics["masked_correct_pct"] = masked_correct_pct - # Checkpointing (same as sync version) consumed_samples += master_config["grpo"]["num_prompts_per_step"] timeout.mark_iteration() @@ -3005,6 +2448,8 @@ def async_grpo_train( if master_config["checkpointing"]["enabled"] and ( should_save_by_step or should_save_by_timeout ): + policy.prepare_for_training() + grpo_save_state["current_step"] = step + 1 grpo_save_state["total_valid_tokens"] = total_valid_tokens if val_metrics is not None: @@ -3053,9 +2498,7 @@ def async_grpo_train( ), optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, + ), tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), @@ -3070,30 +2513,16 @@ def async_grpo_train( os.path.join(checkpoint_path, "train_dataloader.pt"), ) checkpointer.finalize_checkpoint(checkpoint_path) + policy.offload_after_refit() - # Logging - # Log training data (match sync GRPO logging payload for parity) - log_data = {} - if "agent_ref" in repeated_batch: - log_data["agent_ref"] = repeated_batch["agent_ref"] - log_data["content"] = flat_messages_content + log_data = {"content": flat_messages["content"]} log_data["rewards"] = rewards.tolist() - if master_config["grpo"]["use_dynamic_sampling"]: - # In dynamic sampling, `rewards` corresponds to filtered rewards - log_data["filtered_rewards"] = rewards.tolist() - log_data["rewards"] = repeated_batch["total_reward"].tolist() - log_data["input_lengths"] = input_lengths.tolist() - log_data["token_ids"] = train_data["input_ids"].tolist() - log_data["token_loss_mask"] = train_data["token_mask"].tolist() - log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() - log_data["advantages"] = train_data["advantages"].tolist() log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() 
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + log_data["input_lengths"] = input_lengths.tolist() logger.log_batched_dict_as_jsonl( log_data, f"train_data_step{step + 1}.jsonl" ) - del train_data - del flat_messages_content timing_metrics: dict[str, float] = timer.get_timing_metrics( reduction_op="sum" @@ -3108,7 +2537,7 @@ def async_grpo_train( "enable_vllm_metrics_logger", False ) and master_config.get("logger", {}).get("wandb_enabled", False): log_generation_metrics_to_wandb( - generation_logger_metrics, + vllm_logger_metrics, step + 1, master_config["policy"]["generation"]["vllm_cfg"][ "vllm_metrics_logger_interval" @@ -3132,8 +2561,6 @@ def async_grpo_train( print("\n📊 Training Results:") print(f" • Loss: {metrics['loss']:.4f}") - if "draft_loss" in metrics: - print(f" • Draft Loss: {metrics['draft_loss']:.4f}") print(f" • Generation KL Error: {metrics['gen_kl_error']:.4f}") print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}") print(f" • Buffer Size: {buffer_size_current}") diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index df6ff6bc54..557681b4ca 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -11,74 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import Any, NotRequired, Optional, TypedDict, TypeVar +import math +import sys +if sys.version_info >= (3, 11): + from typing import Any, NotRequired, Optional, TypedDict, TypeVar +else: + from typing import Any, Optional, TypedDict, TypeVar + from typing_extensions import NotRequired import torch +import torch.distributed -from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType +from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import DistributedCrossEntropy +from nemo_rl.distributed.model_utils import ( + ChunkedDistributedEntropy, + ChunkedDistributedGatherLogprob, + _compute_distributed_log_softmax, + _get_tokens_on_this_cp_rank, + allgather_cp_sharded_tensor, + from_parallel_logits_to_logprobs, + gather_logits_at_global_indices, + get_logprobs_from_vocab_parallel_logits, +) +from nemo_rl.models.policy.utils import rebuild_cuda_tensor_from_ipc Tensor = TypeVar("Tensor", bound=torch.Tensor) -class DraftCrossEntropyLossConfig(TypedDict): - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] - - -class DraftCrossEntropyLossDataDict(TypedDict): - teacher_logits: Tensor - student_logits: Tensor - token_mask: Tensor - sample_mask: Tensor - student_vocab_indices: NotRequired[Tensor] - - -class DraftCrossEntropyLossFn(LossFunction): - """Compute the auxiliary soft-target cross-entropy used for draft-model training.""" - - loss_type = LossType.TOKEN_LEVEL - input_type = LossInputType.DRAFT - - def __init__( - self, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - ): - self.vocab_parallel_group = vocab_parallel_group - - def __call__( - self, - teacher_logits: Tensor, - student_logits: Tensor, - token_mask: Tensor, - data: BatchedDataDict[DraftCrossEntropyLossDataDict], - global_valid_seqs: torch.Tensor, - 
global_valid_toks: torch.Tensor, - ) -> torch.Tensor: - """Reduce the masked per-token draft loss to a scalar.""" - if self.vocab_parallel_group is not None: - # Soft cross entropy matches the forward-KL student gradient. - per_token_loss = DistributedCrossEntropy.apply( - student_logits, - teacher_logits, - self.vocab_parallel_group, - False, - ) - else: - teacher_probs = torch.nn.functional.softmax(teacher_logits, dim=-1) - student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1) - per_token_loss = -(teacher_probs * student_log_probs).sum(dim=-1) - - mask = token_mask * data["sample_mask"].unsqueeze(-1) - return masked_mean( - per_token_loss, - mask, - global_normalization_factor=global_valid_toks, - ) - - class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float reference_policy_kl_type: str @@ -91,13 +52,6 @@ class ClippedPGLossConfig(TypedDict): use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool truncated_importance_sampling_ratio: float | None - # Type of truncated importance sampling: - # "tis" – clamp IS weights to max - # "icepop" – zero out tokens with IS weight outside [min, max] - # "seq-mask-tis" – zero out sequences by geometric-mean IS ratio, non-truncated token IS correction - truncated_importance_sampling_type: NotRequired[str | None] - # Lower bound for ICE-POP / seq-mask-tis filtering - truncated_importance_sampling_ratio_min: NotRequired[float | None] token_level_loss: bool # If True, apply the off-policy importance-sampling correction at the # sequence level (one weight per generated sample), as in GSPO. 
@@ -110,8 +64,6 @@ class ClippedPGLossConfig(TypedDict): # NOTE: This should only be used when doing exactly one update per rollout # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size) force_on_policy_ratio: NotRequired[bool] - # If True, add KL penalty to reward instead of loss (used by Reinforce++) - use_kl_in_reward: NotRequired[bool] class ClippedPGLossDataDict(TypedDict): @@ -168,8 +120,6 @@ class ClippedPGLossFn(LossFunction): Due to potential numerical instability, we cast the logits to float32 before computing the loss. """ - input_type = LossInputType.LOGPROB - def __init__(self, cfg: ClippedPGLossConfig): self.ratio_clip_min = cfg["ratio_clip_min"] self.ratio_clip_max = cfg["ratio_clip_max"] @@ -189,14 +139,6 @@ def __init__(self, cfg: ClippedPGLossConfig): self.truncated_importance_sampling_ratio = cfg[ "truncated_importance_sampling_ratio" ] - # Type of truncated importance sampling: "tis" | "icepop" | "seq-mask-tis" - self.truncated_importance_sampling_type = cfg.get( - "truncated_importance_sampling_type" - ) - # Lower bound for ICE-POP / seq-mask-tis filtering - self.truncated_importance_sampling_ratio_min = cfg.get( - "truncated_importance_sampling_ratio_min" - ) # Whether to compute importance weights per-sequence instead of per-token. 
self.sequence_level_importance_ratios = cfg.get( "sequence_level_importance_ratios", @@ -216,53 +158,25 @@ def __init__(self, cfg: ClippedPGLossConfig): assert self.truncated_importance_sampling_ratio > 0, ( "truncated_importance_sampling_ratio should be positive" ) - assert self.truncated_importance_sampling_type in ( - "tis", - "icepop", - "seq-mask-tis", - ), ( - f"truncated_importance_sampling_type must be 'tis', 'icepop', or 'seq-mask-tis', " - f"got {self.truncated_importance_sampling_type}" - ) - if self.truncated_importance_sampling_type == "seq-mask-tis": - assert not self.sequence_level_importance_ratios, ( - "seq-mask-tis uses token-level IS correction with sequence-level masking, " - "and is incompatible with sequence_level_importance_ratios=True" - ) - else: - # Warn user that TIS-related parameters are ignored when truncated_importance_sampling_ratio is not set - ignored_params = [] - if cfg.get("truncated_importance_sampling_type") is not None: - ignored_params.append("truncated_importance_sampling_type") - if cfg.get("truncated_importance_sampling_ratio_min") is not None: - ignored_params.append("truncated_importance_sampling_ratio_min") - if ignored_params: - print( - f"[WARN] truncated_importance_sampling_ratio is not set, so the following " - f"parameters are ignored: {', '.join(ignored_params)}. 
" - f"Set truncated_importance_sampling_ratio to enable truncated importance sampling.", - flush=True, - ) def __call__( self, - next_token_logprobs: Tensor, + next_token_logits: Tensor, data: BatchedDataDict[ClippedPGLossDataDict], global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" - curr_logprobs = next_token_logprobs token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] advantages = data["advantages"][:, 1:] prev_logprobs = data["prev_logprobs"][:, 1:] generation_logprobs = data["generation_logprobs"][:, 1:] - if self.reference_policy_kl_penalty != 0: - reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] - curr_logprobs_unfiltered = data.get( - "curr_logprobs_unfiltered", curr_logprobs - ) + reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] + seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -330,41 +244,62 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + curr_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + data["input_ids"], + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] + elif 
isinstance(next_token_logits, torch.distributed.tensor.DTensor): + curr_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"], seq_index=seq_index + ) + else: + next_token_logits_wo_last = next_token_logits[ + :, :-1 + ] # Remove last position's logits + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + curr_logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: - # When top-k/top-p filtering is enabled, we need special handling for KL: - # - reference_policy_logprobs is computed **without** filtering (see use_reference_model) - # - curr_logprobs/prev_logprobs are computed **with** filtering (for actor loss compatibility) - # - For KL, we need curr_logprobs **without** filtering to be consistent with ref logprobs - # - For importance weights, we also use unfiltered curr_logprobs_unfiltered since we're - # reweighting samples from π_gen_filtered to π_curr_unfiltered - - # On-policy KL approximation if self.use_on_policy_kl_approximation: # See: docs/guides/grpo.md#on-policy-kl-approximation kl_importance_weights = torch.exp( - curr_logprobs_unfiltered - generation_logprobs + curr_logprobs - generation_logprobs ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) else: - kl_importance_weights = torch.ones_like(curr_logprobs_unfiltered) - - # Compute KL loss + kl_importance_weights = torch.ones_like(curr_logprobs) kl = ( kl_importance_weights * self.reference_policy_kl_penalty * calculate_kl( - logprobs=curr_logprobs_unfiltered, + logprobs=curr_logprobs, logprobs_reference=reference_policy_logprobs, kl_type=self.reference_policy_kl_type, input_clamp_value=self.kl_input_clamp_value, output_clamp_value=self.kl_output_clamp_value, ) ) - - # Reduce KL 
loss if self.loss_type == LossType.TOKEN_LEVEL: kl = masked_mean( kl, mask, global_normalization_factor=global_valid_toks @@ -423,7 +358,6 @@ def __call__( # ------------------------------------------------------------- # Off-policy (actor) importance-sampling correction # ------------------------------------------------------------- - _is_filter_metrics: dict = {} # populated for icepop / seq-mask-tis # See: docs/guides/grpo.md#importance-sampling-correction if self.sequence_level_importance_ratios: # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) @@ -442,87 +376,12 @@ def __call__( actor_importance_weights_expanded = torch.nan_to_num( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) - # ---- Truncated Importance Sampling ---- - # "tis" – clamp IS weights to [0, max] - # "icepop" – zero out tokens whose IS weight ∉ [min, max] (ref bounds: 0.5–5) - # "seq-mask-tis" – zero out entire sequences whose geometric-mean - # IS ratio ∉ [min, max]; retained sequences keep - # raw (non-truncated) token-level IS weights (ref bounds: 0.999–1.002) - # Blog: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + # TIS see https://fengyao.notion.site/off-policy-rl if self.truncated_importance_sampling_ratio is not None: - if self.truncated_importance_sampling_type == "tis": - token_in_bounds = ( - actor_importance_weights_expanded - <= self.truncated_importance_sampling_ratio - ) - _is_filter_metrics = { - "is_oob_ratio": 1.0 - - masked_mean( - token_in_bounds.float(), - mask, - global_normalization_factor=global_valid_toks, - ).item(), - } - actor_importance_weights_expanded = torch.clamp( - actor_importance_weights_expanded, - max=self.truncated_importance_sampling_ratio, - ) - elif self.truncated_importance_sampling_type == "icepop": - token_kept_mask = ( - actor_importance_weights_expanded - >= self.truncated_importance_sampling_ratio_min - ) & 
( - actor_importance_weights_expanded - <= self.truncated_importance_sampling_ratio - ) - _is_filter_metrics = { - "is_oob_ratio": 1.0 - - masked_mean( - token_kept_mask.float(), - mask, - global_normalization_factor=global_valid_toks, - ).item(), - } - actor_importance_weights_expanded = torch.where( - token_kept_mask, - actor_importance_weights_expanded, - torch.zeros_like(actor_importance_weights_expanded), - ) - elif self.truncated_importance_sampling_type == "seq-mask-tis": - # geo_mean_i = exp( mean_t( log(π_prev / π_gen) ) ) - log_is_ratio = torch.nan_to_num( - prev_logprobs - generation_logprobs, - nan=0.0, - posinf=0.0, - neginf=0.0, - ) - seq_log_is_ratio_mean = masked_mean( - log_is_ratio, token_mask, dim=-1 - ) # [B] - seq_geomean_is_ratio = torch.exp(seq_log_is_ratio_mean).detach() # [B] - seq_kept_mask = ( - ( - seq_geomean_is_ratio - >= self.truncated_importance_sampling_ratio_min - ) - & (seq_geomean_is_ratio <= self.truncated_importance_sampling_ratio) - ).float() # [B] - _is_filter_metrics = { - "is_oob_ratio": 1.0 - - masked_mean( - seq_kept_mask, - sample_mask, - global_normalization_factor=global_valid_seqs, - ).item(), - } - actor_importance_weights_expanded = ( - actor_importance_weights_expanded * seq_kept_mask.unsqueeze(-1) - ) - else: - raise ValueError( - f"Invalid truncated importance sampling type: {self.truncated_importance_sampling_type}" - ) - + actor_importance_weights_expanded = torch.clamp( + actor_importance_weights_expanded, + max=self.truncated_importance_sampling_ratio, + ) actor_importance_weights = actor_importance_weights_expanded del actor_importance_weights_expanded if self.use_importance_sampling_correction: @@ -621,26 +480,24 @@ def __call__( "sampling_importance_ratio": sample_importance_ratio.item(), "num_valid_samples": sample_mask.sum().item(), "approx_entropy": seq_entropy_approx.item(), - **_is_filter_metrics, }, ) -class NLLLossFn(LossFunction): +class NLLLoss(LossFunction): """Negative Log Likelihood Loss 
function.""" loss_type = LossType.TOKEN_LEVEL - input_type = LossInputType.LOGPROB - - def __init__(self, use_linear_ce_fusion: bool = False): - self.use_linear_ce_fusion = use_linear_ce_fusion def __call__( self, - next_token_logprobs: Tensor, + next_token_logits: Tensor, data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, ) -> tuple[torch.Tensor, dict[str, Any]]: @@ -649,19 +506,52 @@ def __call__( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) + seq_index = data.get("seq_index", None) + + next_token_logits = next_token_logits.to(torch.float32) + + # Gather the logprobs for the actual next tokens + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + token_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + data["input_ids"], + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"], seq_index=seq_index + ) + else: + next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = 
next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) if dpo_loss: ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) ## multiply by sample_mask to zero out invalid samples - loss = -torch.sum(next_token_logprobs * mask, dim=-1) + loss = -torch.sum(token_logprobs * mask, dim=-1) if dpo_average_log_probs: loss = loss / num_unmasked_tokens.clamp(min=1) else: ## single scalar loss ## scale by the total number of tokens in the batch loss = -masked_mean( - next_token_logprobs, + token_logprobs, mask, global_normalization_factor=global_valid_toks, ) @@ -681,7 +571,7 @@ class PreferenceLossDataDict(TypedDict): sample_mask: torch.Tensor -class PreferenceLossFn(LossFunction): +class PreferenceLoss(LossFunction): """Preference Loss function. Optimizes the model to prefer chosen responses over rejected ones @@ -702,8 +592,8 @@ class PreferenceLossFn(LossFunction): - accuracy: Fraction of examples where chosen response has higher reward """ - loss_type = LossType.SEQUENCE_LEVEL - input_type = LossInputType.LOGIT + def __init__(self): + self.loss_type = LossType.SEQUENCE_LEVEL def split_output_tensor(self, tensor: Tensor) -> tuple[Tensor, Tensor]: # tensor is of shape (2*micro_batch_size,) @@ -749,14 +639,14 @@ def _preference_loss( def __call__( self, - logits: Tensor, + rewards: Tensor, data: BatchedDataDict[PreferenceLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, ) -> tuple[torch.Tensor, dict[str, Any]]: sample_mask = data["sample_mask"] - rewards = logits.squeeze(-1) + rewards = rewards.squeeze(-1) ( preference_loss, @@ -794,7 +684,7 @@ class DPOLossDataDict(TypedDict): sample_mask: torch.Tensor -class DPOLossFn(PreferenceLossFn): +class DPOLossFn(PreferenceLoss): """Direct Preference Optimization (DPO) loss function. 
This loss function implements the DPO algorithm as described in: @@ -850,30 +740,63 @@ class DPOLossFn(PreferenceLossFn): - accuracy: Fraction of examples where chosen response has higher reward """ - loss_type = LossType.SEQUENCE_LEVEL - input_type = LossInputType.LOGPROB - - def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False): + def __init__(self, cfg: DPOLossConfig): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] self.sft_loss_weight = cfg["sft_loss_weight"] self.preference_average_log_probs = cfg["preference_average_log_probs"] self.sft_average_log_probs = cfg["sft_average_log_probs"] - self.use_linear_ce_fusion = use_linear_ce_fusion - self.sft_loss = NLLLossFn(use_linear_ce_fusion=use_linear_ce_fusion) + self.sft_loss = NLLLoss() + + self.loss_type = LossType.SEQUENCE_LEVEL def _dpo_loss( self, - next_token_logprobs: Tensor, + next_token_logits: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: - ## TODO(@ashors): there's some duplicate code here with the NLLLossFn function. We should refactor + ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. 
We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] + seq_index = data.get("seq_index", None) + + next_token_logits = next_token_logits.to(torch.float32) + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + token_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + data["input_ids"], + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"], seq_index=seq_index + ) + else: + next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) ref_logprobs = data["reference_policy_logprobs"][:, :-1] - diff = (next_token_logprobs - ref_logprobs) * token_mask + + diff = (token_logprobs - ref_logprobs) * token_mask rewards = diff.sum(-1) if self.preference_average_log_probs: @@ -883,13 +806,16 @@ def _dpo_loss( rewards, sample_mask, global_valid_seqs, self.reference_policy_kl_penalty ) - # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLossFn) + # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) def __call__( # type: ignore self, - next_token_logprobs: 
Tensor, + next_token_logits: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -897,10 +823,13 @@ def __call__( # type: ignore "global_valid_toks must be provided for SFT loss" ) sft_loss, _ = self.sft_loss( - next_token_logprobs, + next_token_logits, data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) @@ -916,7 +845,14 @@ def __call__( # type: ignore accuracy, rewards_chosen_mean, rewards_rejected_mean, - ) = self._dpo_loss(next_token_logprobs, data, global_valid_seqs) + ) = self._dpo_loss( + next_token_logits, + data, + global_valid_seqs, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen @@ -937,6 +873,103 @@ def __call__( # type: ignore } +class SequencePackingLossWrapper: + def __init__( + self, + loss_fn: LossFunction, + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + ): + self.loss_fn = loss_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = cu_seqlens_q_padded + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + 
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" + unpadded_cu_seqlens = self.cu_seqlens_q + unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + if self.cu_seqlens_q_padded is not None: + padded_cu_seqlens = self.cu_seqlens_q_padded + padded_seq_lengths = ( + self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] + ) + else: + padded_cu_seqlens = unpadded_cu_seqlens + padded_seq_lengths = unpadded_seq_lengths + seq_starts = padded_cu_seqlens[:-1] + seq_ends = padded_cu_seqlens[1:] + + loss_accum = 0 + metrics_accum = {} + for seq_idx in range(len(seq_starts)): + seq_start = seq_starts[seq_idx].item() + seq_end = seq_ends[seq_idx].item() + + # get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors + seq_data = data.slice(seq_idx, seq_idx + 1) + unpadded_seq_data = {} + for k, v in seq_data.items(): + if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: + unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] + else: + unpadded_seq_data[k] = v + + # get next_token_logits + cp_size = ( + 1 + if context_parallel_group is None + else torch.distributed.get_world_size(context_parallel_group) + ) + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length + ) + + loss, metrics = self.loss_fn( + next_token_logits_slice, + unpadded_seq_data, + global_valid_seqs, + global_valid_toks, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + loss_accum += loss + for k, v in metrics.items(): + if k not in metrics_accum: + if k in {"probs_ratio_min", 
"probs_ratio_clamped_min"}: + metrics_accum[k] = float("inf") + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + metrics_accum[k] = float("-inf") + else: + metrics_accum[k] = 0 + + val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v + + # Skip inf/-inf sentinel values (from sequences with no valid tokens) + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + if not math.isinf(val): + metrics_accum[k] = min(metrics_accum[k], val) + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + if not math.isinf(val): + metrics_accum[k] = max(metrics_accum[k], val) + else: + metrics_accum[k] += val + + return loss_accum, metrics_accum + + class DistillationLossConfig(TypedDict): kl_type: str mixed_kl_weight: float @@ -955,72 +988,345 @@ class DistillationLossDataDict(TypedDict): class DistillationLossFn(LossFunction): """Distillation loss function.""" - loss_type = LossType.TOKEN_LEVEL - input_type = LossInputType.DISTILLATION - def __init__(self, cfg: DistillationLossConfig): self.kl_type = cfg["kl_type"] - self.mixed_kl_weight = cfg["mixed_kl_weight"] + self.mixed_kl_weight = cfg.get("mixed_kl_weight", 0.5) self.zero_outside_topk = cfg["zero_outside_topk"] self.log_infinitesimal = -100 + self.loss_type = LossType.TOKEN_LEVEL assert self.kl_type in ["forward", "reverse", "mixed"], "Invalid KL type" - assert self.mixed_kl_weight >= 0 and self.mixed_kl_weight <= 1, ( + assert 0 <= self.mixed_kl_weight <= 1, ( "Invalid mixed KL weight" ) def __call__( self, - student_topk_logprobs: torch.Tensor, - teacher_topk_logprobs: torch.Tensor, - H_all: torch.Tensor | None, + next_token_logits: torch.Tensor, data: DistillationLossDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional = None, + mb_idx: 
Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute distillation loss between teacher and student logits.""" - student_probs = student_topk_logprobs.exp() # [B, S-1, k] - teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k] - - loss_correction_term = torch.zeros_like(student_probs[..., 0]) # [B, S-1] - if self.zero_outside_topk and self.kl_type != "forward": - H_rest = H_all - (student_probs * student_topk_logprobs).sum(-1) - P_rest = 1 - (student_probs.sum(-1)) - # The entropy and prob of the rest of the tokens [B, S-1] - loss_correction_term = H_rest - self.log_infinitesimal * P_rest # [B, S-1] - if self.kl_type == "mixed": - loss_correction_term = loss_correction_term * ( - 1.0 - self.mixed_kl_weight + # Basic shapes + input_ids = data["input_ids"] + batch_size = input_ids.shape[0] + + # CP support: get CP group and size. + # Prefer the explicitly-passed group; fall back to the DTensor device + # mesh so the IPC path works even when the caller doesn't pass it. 
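With `kl_type="mixed"`, the loss blends the two divergence directions using `mixed_kl_weight`. A small plain-Python sketch of the blend over explicit probability vectors (names are illustrative):

```python
import math

def kl(p, q):
    """Forward KL(p || q) over explicit probability vectors."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mixed_kl(teacher, student, w):
    """w * KL(teacher || student) + (1 - w) * KL(student || teacher),
    matching kl_type='mixed' with mixed_kl_weight=w."""
    return w * kl(teacher, student) + (1.0 - w) * kl(student, teacher)

loss = mixed_kl([0.7, 0.2, 0.1], [0.5, 0.3, 0.2], w=0.5)
```

`w=1.0` recovers pure forward KL (mode-covering) and `w=0.0` pure reverse KL (mode-seeking).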
+ cp_group = context_parallel_group + if cp_group is None and isinstance(next_token_logits, torch.distributed.tensor.DTensor): + mesh = next_token_logits.device_mesh + if mesh.mesh_dim_names is not None and "cp" in mesh.mesh_dim_names: + cp_group = mesh.get_group("cp") + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + + # Ensure float32 for stability (match other losses) + next_token_logits = next_token_logits.to(torch.float32) + per_token_kl = None + + # ===== IPC PATH: teacher logits passed as pre-reconstructed tensor ===== + if teacher_logits is not None and teacher_topk_indices_ipc is not None: + # Resolve TP-local student logits + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + device_mesh = next_token_logits.device_mesh + tp_group = device_mesh.get_group("tp") + tp_rank = tp_group.rank() + local_student_logits = next_token_logits.to_local() + V_local = int(local_student_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + else: + tp_group = None + tp_rank = 0 + local_student_logits = next_token_logits + V_local = int(local_student_logits.shape[-1]) + vocab_start_index = 0 + vocab_end_index = V_local + + with torch.no_grad(): + if mb_idx is not None and mbs is not None: + mb_start = mb_idx * mbs + mb_end = mb_start + mbs + teacher_topk_logprobs = teacher_logits[mb_start:mb_end, :, :].clone().detach() + topk_indices = teacher_topk_indices_ipc[mb_start:mb_end, :, :].clone().detach() + else: + teacher_topk_logprobs = teacher_logits.clone().detach() + topk_indices = teacher_topk_indices_ipc.clone().detach() + teacher_topk_logprobs = teacher_topk_logprobs.to(device=local_student_logits.device) + topk_indices = topk_indices.to(device=local_student_logits.device) + + # Gather student log probs at teacher's top-k global indices + if tp_group is not None: + S_local = int(local_student_logits.shape[1]) + chunk_size = max(1, min(S_local, 1024)) + 
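When the shared teacher buffer holds the whole global batch, each micro-batch reads the row range `[mb_idx * mbs, mb_idx * mbs + mbs)`, as in the slicing above. A sketch (the cap on `total_rows` is an added safeguard, not from the source; tensor slicing caps implicitly):

```python
def microbatch_rows(mb_idx, mbs, total_rows):
    """Row range of a shared (e.g. IPC) teacher buffer covered by one
    micro-batch: start = mb_idx * mbs, end = start + mbs (capped)."""
    start = mb_idx * mbs
    return start, min(start + mbs, total_rows)

rows = [microbatch_rows(i, 4, 10) for i in range(3)]
```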
student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( + local_student_logits, + topk_indices, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + False, + ) + else: + student_logprobs = torch.nn.functional.log_softmax( + local_student_logits, dim=-1 + ) + student_topk_logprobs = torch.gather( + student_logprobs, dim=-1, index=topk_indices + ) + del student_logprobs + del local_student_logits + + if self.kl_type == "reverse": + teacher_topk_logprobs, student_topk_logprobs = student_topk_logprobs, teacher_topk_logprobs + + # Build (k+1)-dim distributions with a "rest" bucket + teacher_topk_probs = teacher_topk_logprobs.exp() + teacher_rest_prob = (1.0 - teacher_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + teacher_probs_full = torch.cat([teacher_topk_probs, teacher_rest_prob], dim=-1) + teacher_logprobs_full = torch.cat([teacher_topk_logprobs, teacher_rest_prob.log()], dim=-1) + + student_topk_probs = student_topk_logprobs.exp() + student_rest_prob = (1.0 - student_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + student_logprobs_full = torch.cat([student_topk_logprobs, student_rest_prob.log()], dim=-1) + + per_token_kl = (teacher_probs_full * (teacher_logprobs_full - student_logprobs_full)).sum(dim=-1) + + del teacher_topk_logprobs, teacher_topk_probs, teacher_rest_prob + del teacher_probs_full, teacher_logprobs_full + del student_topk_logprobs, student_topk_probs, student_rest_prob + del student_logprobs_full, topk_indices + + # Next-token alignment + per_token_kl = per_token_kl[:, :-1] + + # ===== FULL-LOGPROB IPC PATH: teacher provides full vocab logprobs via IPC ===== + elif teacher_logits is not None: + # Resolve TP-local student logits + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + device_mesh = next_token_logits.device_mesh + tp_group = device_mesh.get_group("tp") + local_student_logits = next_token_logits.to_local() + else: + tp_group = None + local_student_logits = next_token_logits + + 
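The (k+1)-dimensional construction above appends one "rest" bucket holding the probability mass outside the teacher's top-k, so the KL is taken over proper distributions. A torch-free sketch of that forward KL (helper name is illustrative):

```python
import math

def topk_rest_bucket_kl(teacher_topk_logp, student_topk_logp):
    """Forward KL over a (k+1)-way distribution: the top-k entries plus
    one 'rest' bucket holding the mass outside top-k, clamped to eps."""
    eps = 1e-10
    t_probs = [math.exp(lp) for lp in teacher_topk_logp]
    s_probs = [math.exp(lp) for lp in student_topk_logp]
    t_rest = max(1.0 - sum(t_probs), eps)
    s_rest = max(1.0 - sum(s_probs), eps)
    t_full = t_probs + [t_rest]
    t_logp = teacher_topk_logp + [math.log(t_rest)]
    s_logp = student_topk_logp + [math.log(s_rest)]
    return sum(p * (tl - sl) for p, tl, sl in zip(t_full, t_logp, s_logp))

kl = topk_rest_bucket_kl([math.log(0.6), math.log(0.3)],
                         [math.log(0.5), math.log(0.3)])
```

Without the rest bucket, a truncated top-k distribution would not sum to one and the "KL" could go negative; the bucket preserves the true leftover mass.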
with torch.no_grad(): + if mb_idx is not None and mbs is not None: + mb_start_index = mb_idx * mbs + mb_end_index = mb_start_index + mbs + teacher_logprobs_local = teacher_logits[mb_start_index:mb_end_index, :, :].clone().detach() + else: + teacher_logprobs_local = teacher_logits.clone().detach() + teacher_logprobs_local = teacher_logprobs_local.to(device=local_student_logits.device) + + if tp_group is not None: + # Differentiable distributed log-softmax for student logits. + # The normalization constants are computed under no_grad, + # but the final log-softmax is built with differentiable ops + # so autograd can back-propagate through the student model. + with torch.no_grad(): + logits_max = torch.amax(local_student_logits, dim=-1, keepdim=True) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group + ) + shifted = local_student_logits - logits_max + local_sum_exp = shifted.exp().sum(-1, keepdim=True) + global_sum_exp = torch.distributed.nn.functional.all_reduce( + local_sum_exp, op=torch.distributed.ReduceOp.SUM, group=tp_group ) + student_logprobs_local = shifted - global_sum_exp.log() + del shifted, local_sum_exp, global_sum_exp + else: + student_logprobs_local = torch.nn.functional.log_softmax(local_student_logits, dim=-1) - if self.kl_type == "forward": - per_token_kl = teacher_probs * ( - teacher_topk_logprobs - student_topk_logprobs - ) - elif self.kl_type == "reverse": - per_token_kl = student_probs * ( - student_topk_logprobs - teacher_topk_logprobs - ) + per_token_kl = teacher_logprobs_local.exp() * (teacher_logprobs_local - student_logprobs_local) + per_token_kl = per_token_kl.sum(-1) + del teacher_logprobs_local, student_logprobs_local, local_student_logits + + if tp_group is not None: + per_token_kl = torch.distributed.nn.functional.all_reduce( + per_token_kl, op=torch.distributed.ReduceOp.SUM, group=tp_group + ) + + # Next-token alignment + per_token_kl = per_token_kl[:, :-1] + + # ===== STANDARD PATH: 
teacher top-k in data dict ===== else: - # mixed KL - kl_forward = teacher_probs * (teacher_topk_logprobs - student_topk_logprobs) - kl_reverse = student_probs * (student_topk_logprobs - teacher_topk_logprobs) - per_token_kl = ( - self.mixed_kl_weight * kl_forward - + (1.0 - self.mixed_kl_weight) * kl_reverse + teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k] + teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k] + + if teacher_topk_indices.shape[-1] <= 0: + raise ValueError( + f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " + "topk=0 is not supported as it would result in empty tensor operations." + ) + + # Determine processing path and setup variables + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + V_local = int(next_token_logits.shape[-1]) + vocab_start_index = vocab_parallel_rank * V_local + vocab_end_index = (vocab_parallel_rank + 1) * V_local + parallel_group = vocab_parallel_group + logits_tensor = next_token_logits + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + device_mesh = next_token_logits.device_mesh + tp_group = device_mesh.get_group("tp") + tp_rank = tp_group.rank() + local_student_logits = next_token_logits.to_local() + V_local = int(local_student_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + parallel_group = tp_group + logits_tensor = local_student_logits + teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device) + if ( + device_mesh.mesh_dim_names is not None + and "cp" in device_mesh.mesh_dim_names + ): + cp_group = device_mesh.get_group("cp") + cp_size = cp_group.size() + else: + cp_group = None + cp_size = 1 + else: + parallel_group = None + logits_tensor = next_token_logits + + # Process based on zero_outside_topk setting + if self.zero_outside_topk and parallel_group is not None: + 
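The full-logprob IPC path earlier computes a numerically stable log-softmax over a vocabulary sharded across TP ranks: all-reduce the max, shift, then all-reduce the shard-local sums of exponentials. The same arithmetic can be checked single-process with shards as plain lists (a sketch, not the distributed implementation):

```python
import math

def sharded_log_softmax(shards):
    """log-softmax over a vocab split into per-rank shards, mirroring the
    TP pattern: all_reduce(MAX) for stability, then all_reduce(SUM) of the
    shard-local sum-exp, then logp = (x - max) - log(global_sum_exp)."""
    global_max = max(max(s) for s in shards)  # stands in for all_reduce(MAX)
    global_sum = sum(math.exp(x - global_max)
                     for s in shards for x in s)  # stands in for all_reduce(SUM)
    return [[(x - global_max) - math.log(global_sum) for x in s]
            for s in shards]

out = sharded_log_softmax([[1.0, 2.0], [0.5, -1.0]])
```

Each rank can then compute its shard of `logp` locally; the two collectives are the only cross-rank communication required.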
indices_local = teacher_topk_indices + pad_len = 0 + if cp_size > 1: + pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1] + if pad_len > 0: + indices_local = torch.nn.functional.pad( + indices_local, (0, 0, 0, pad_len), value=0 + ) + cp_rank = torch.distributed.get_rank(cp_group) + indices_local = _get_tokens_on_this_cp_rank( + indices_local, cp_rank, cp_size, seq_dim=1 + ) + + S_local = int(logits_tensor.shape[1]) + chunk_size = max(1, min(S_local, 1024)) + student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore + logits_tensor, + indices_local, + vocab_start_index, + vocab_end_index, + chunk_size, + parallel_group, + False, + ) + + if self.kl_type != "forward": + H_all = ChunkedDistributedEntropy.apply( # type: ignore + logits_tensor, + chunk_size, + parallel_group, + False, + ) + + if cp_size > 1: + student_topk_logprobs = allgather_cp_sharded_tensor( + student_topk_logprobs, cp_group, seq_dim=1 + ) + if self.kl_type != "forward": + H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) + if pad_len > 0: + student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] + if self.kl_type != "forward": + H_all = H_all[:, :-pad_len] + elif self.zero_outside_topk: + student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1) + student_topk_logprobs = student_logprobs.gather( + dim=-1, index=teacher_topk_indices.to(student_logprobs.device) + ) + if self.kl_type != "forward": + H_all = (student_logprobs.exp() * student_logprobs).sum(-1) + else: + if (parallel_group is not None) or (cp_size > 1): + student_topk_logits = gather_logits_at_global_indices( + logits_tensor, + teacher_topk_indices, + tp_group=parallel_group, + cp_group=cp_group, + vocab_start_index=( + vocab_start_index if parallel_group is not None else 0 + ), + vocab_end_index=( + vocab_end_index + if parallel_group is not None + else int(logits_tensor.shape[-1]) + ), + ) + else: + student_topk_logits = logits_tensor.gather( + dim=-1, 
index=teacher_topk_indices.to(logits_tensor.device) + ) + student_topk_logprobs = torch.nn.functional.log_softmax( + student_topk_logits, dim=-1 + ) + + teacher_topk_logits = teacher_topk_logits.to( + student_topk_logprobs.device, dtype=student_topk_logprobs.dtype ) - per_token_kl = per_token_kl.sum(dim=-1) + loss_correction_term # [B, S-1] + # Use the teacher's top-k values as log-probabilities directly + # (get_topk_logits returns log-probs when using DTensor/TP). + # Build (k+1)-dim distributions with a "rest" bucket to match + # the IPC path and preserve the true probability mass outside top-k. + teacher_topk_logprobs = teacher_topk_logits + + # Single point of next-token alignment after TP/CP processing + teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] + student_topk_logprobs = student_topk_logprobs[:, :-1, :] + + if self.kl_type == "reverse": + teacher_topk_logprobs, student_topk_logprobs = student_topk_logprobs, teacher_topk_logprobs + + teacher_topk_probs = teacher_topk_logprobs.exp() + teacher_rest_prob = (1.0 - teacher_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + teacher_probs_full = torch.cat([teacher_topk_probs, teacher_rest_prob], dim=-1) + teacher_logprobs_full = torch.cat([teacher_topk_logprobs, teacher_rest_prob.log()], dim=-1) + + student_topk_probs = student_topk_logprobs.exp() + student_rest_prob = (1.0 - student_topk_probs.sum(dim=-1, keepdim=True)).clamp(min=1e-10) + student_logprobs_full = torch.cat([student_topk_logprobs, student_rest_prob.log()], dim=-1) + + per_token_kl = (teacher_probs_full * (teacher_logprobs_full - student_logprobs_full)).sum(dim=-1) + + del teacher_topk_probs, teacher_rest_prob, teacher_probs_full, teacher_logprobs_full + del student_topk_probs, student_rest_prob, student_logprobs_full # Masking and reduction if "token_mask" in data and "sample_mask" in data: token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - # Align mask length to current per_token_kl max_len = 
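The (k+1)-bucket construction above — top-k log-probs plus one "rest" bucket holding the probability mass outside the top-k — can be exercised in isolation. A minimal sketch, assuming only that inputs are log-probabilities whose exponentials sum to at most one (`topk_rest_kl` is an illustrative name, not part of NeMo RL):

```python
import torch

def topk_rest_kl(teacher_topk_logprobs: torch.Tensor,
                 student_topk_logprobs: torch.Tensor) -> torch.Tensor:
    """KL(teacher || student) over k top-k buckets plus one 'rest' bucket."""
    t_probs = teacher_topk_logprobs.exp()
    s_probs = student_topk_logprobs.exp()
    # leftover mass outside the top-k, clamped exactly as in the loss above
    t_rest = (1.0 - t_probs.sum(-1, keepdim=True)).clamp(min=1e-10)
    s_rest = (1.0 - s_probs.sum(-1, keepdim=True)).clamp(min=1e-10)
    t_full = torch.cat([t_probs, t_rest], dim=-1)  # valid (k+1)-way distribution
    t_log = torch.cat([teacher_topk_logprobs, t_rest.log()], dim=-1)
    s_log = torch.cat([student_topk_logprobs, s_rest.log()], dim=-1)
    return (t_full * (t_log - s_log)).sum(-1)

torch.manual_seed(0)
# pretend the first 3 entries of each softmax are the gathered top-3 buckets
t = torch.log_softmax(torch.randn(2, 4, 16), dim=-1)[..., :3]
s = torch.log_softmax(torch.randn(2, 4, 16), dim=-1)[..., :3]
kl = topk_rest_kl(t, s)
```

Because both bucketed vectors are valid distributions, the result is non-negative and vanishes when teacher and student agree — which is why the rest bucket is needed to preserve the true mass outside the top-k.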
per_token_kl.shape[1] - token_mask = token_mask[:, :max_len] - mask = token_mask * sample_mask.unsqueeze(-1) # [B, S-1] - # align mask shape to per_token_kl + if cp_size > 1: + cp_rank = torch.distributed.get_rank(cp_group) + S_local = max_len + 1 + start = cp_rank * S_local + token_mask = token_mask[:, start:start + max_len] + else: + token_mask = token_mask[:, :max_len] + mask = token_mask * sample_mask.unsqueeze(-1) kl_loss = masked_mean( per_token_kl, mask, @@ -1031,7 +1337,569 @@ def __call__( metrics = { "loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, - "num_valid_samples": data["input_ids"].shape[0], + "num_valid_samples": int(batch_size), } return kl_loss, metrics + + +# ============================================================================= +# Cross-Tokenizer Distillation Loss (via TokenAligner) +# ============================================================================= + + +class CrossTokenizerDistillationLossConfig(TypedDict): + """Configuration for cross-tokenizer distillation loss.""" + loss_type: str # 'KL', 'cross_entropy', or 'chunked_ce' + temperature: float # Softmax temperature + vocab_topk: int # Reduce teacher vocab to top-k (0 = all) + exact_token_match_only: bool # Only use 1:1 aligned positions + reverse_kl: bool # Reverse KL direction + project_teacher_to_student: NotRequired[bool] + gold_loss: NotRequired[bool] # Use gold loss (common KL + uncommon L1, no projection) + xtoken_loss: NotRequired[bool] # Relaxed exact-map threshold (>=0.6 instead of ==1.0) + ce_loss_scale: NotRequired[float] # Scale for additional CE (next-token) loss (0.0 = disabled) + dynamic_loss_scaling: NotRequired[bool] # Scale KL loss to match CE magnitude + + +class CrossTokenizerDistillationLossDataDict(TypedDict): + """Data dict for cross-tokenizer distillation. + + Only contains student-side tensors (same sequence dimension). 
+ Teacher-side data (teacher_input_ids, aligned_pairs) is stored on the + loss function instance via set_cross_tokenizer_data() to avoid + sequence-length mismatches in the worker's shape validation. + """ + input_ids: torch.Tensor # Student token IDs (B, S_student) + input_lengths: torch.Tensor + token_mask: torch.Tensor # (B, S_student) + sample_mask: torch.Tensor # (B,) + + +class CrossTokenizerDistillationLossFn(LossFunction): + """Cross-tokenizer distillation loss using TokenAligner's projection matrix. + + Computes per-token KL divergence between projected student probabilities + (in teacher vocab space) and teacher probabilities, only at positions where + the two tokenizations have 1:1 aligned tokens. Uses NeMo RL's standard + masked_mean normalization so loss magnitude is comparable to same-tokenizer + distillation. + + Teacher-specific data (teacher_input_ids, aligned_pairs) is stored on + this object via set_cross_tokenizer_data() before each training step, + rather than in the data dict, because teacher and student sequences + have different lengths and the worker validates that all tensors in + the data dict share the same sequence dimension. + """ + + def __init__(self, cfg: CrossTokenizerDistillationLossConfig, token_aligner): + from nemo_rl.algorithms.x_token import TokenAligner + assert isinstance(token_aligner, TokenAligner) + self.token_aligner = token_aligner + self.cfg = cfg + self.loss_type = LossType.TOKEN_LEVEL + self._teacher_input_ids = None + self._aligned_pairs = None + + def set_cross_tokenizer_data( + self, + teacher_input_ids: torch.Tensor, + aligned_pairs: list, + ): + """Store teacher-side data before each training step. + + Called from the training loop before student_policy.train(). + The worker never sees these tensors in shape validation. 
+ """ + self._teacher_input_ids = teacher_input_ids + self._aligned_pairs = aligned_pairs + + def _project_student_to_teacher( + self, + student_logits: torch.Tensor, + teacher_vocab_size: int, + temperature: float, + global_top_indices: torch.Tensor, + device: torch.device, + ) -> torch.Tensor: + """Project student logits into the reduced teacher vocabulary space. + + Returns projected student probabilities of shape (B, S_student, K) + where K = len(global_top_indices). + """ + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + has_sparse = ( + hasattr(self.token_aligner, 'sparse_transformation_matrix') + and self.token_aligner.sparse_transformation_matrix is not None + ) + if has_sparse: + sparse_mat = self.token_aligner.sparse_transformation_matrix + reduced_sparse = sparse_mat.index_select(1, global_top_indices).coalesce() + projected = self.token_aligner.project_token_likelihoods_instance( + student_probs, None, None, None, device, + use_sparse_format=True, + sparse_matrix=reduced_sparse, + ) + return projected + + proj_values = self.token_aligner.likelihood_projection_matrix + if getattr(self.token_aligner, 'learnable', False): + proj_values = self.token_aligner.transform_learned_matrix_instance(proj_values) + projected_full = self.token_aligner.project_token_likelihoods_instance( + student_probs, self.token_aligner.likelihood_projection_indices, + proj_values, teacher_vocab_size, device, + use_sparse_format=False, + ) + return projected_full[:, :, global_top_indices] + + def _compute_gold_loss( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + aligned_pairs: list, + batch_size: int, + student_seq_len: int, + teacher_seq_len: int, + student_vocab_size: int, + teacher_vocab_size: int, + temperature: float, + reverse_kl: bool, + xtoken_loss: bool, + device: torch.device, + ) -> tuple[torch.Tensor, float]: + """Gold loss: common-vocab KL + uncommon-vocab sorted L1. 
+ + Splits the vocabulary into tokens with exact 1:1 projection mappings + ("common") and the rest ("uncommon"). Common tokens are compared + directly via KL on their native log-probs (no projection needed). + Uncommon tokens are compared via L1 on sorted probability vectors + (Universal Likelihood Distillation). + + Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + """ + aligner = self.token_aligner + if not hasattr(aligner, 'likelihood_projection_indices') or aligner.likelihood_projection_indices is None: + raise ValueError("gold_loss requires likelihood_projection_indices to be loaded") + + projection_indices = aligner.likelihood_projection_indices + projection_matrix = ( + aligner.transform_learned_matrix_instance(aligner.likelihood_projection_matrix) + if getattr(aligner, 'learnable', False) + else aligner.likelihood_projection_matrix + ) + + sorted_values, sorted_indices_in_topk = torch.sort(projection_matrix, dim=-1, descending=True) + + if xtoken_loss: + has_exact_map = (sorted_values[:, 0] >= 0.6) + else: + has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1) + + student_indices_with_exact_map = torch.where(has_exact_map)[0] + teacher_indices_for_exact_map = projection_indices[ + student_indices_with_exact_map, + sorted_indices_in_topk[student_indices_with_exact_map, 0], + ] + + student_to_teacher_exact_map: dict[int, int] = {} + teacher_to_student_exact_map: dict[int, int] = {} + for s_idx, t_idx in zip( + student_indices_with_exact_map.tolist(), + teacher_indices_for_exact_map.tolist(), + ): + if 0 <= t_idx < teacher_vocab_size: + if t_idx not in teacher_to_student_exact_map or xtoken_loss: + if t_idx in teacher_to_student_exact_map: + prev_student_token = teacher_to_student_exact_map[t_idx] + if sorted_values[prev_student_token, 0] >= sorted_values[s_idx, 0]: + continue + del student_to_teacher_exact_map[prev_student_token] + student_to_teacher_exact_map[s_idx] = t_idx + teacher_to_student_exact_map[t_idx] = 
s_idx + + common_student_indices = sorted(student_to_teacher_exact_map.keys()) + common_teacher_indices = [student_to_teacher_exact_map[s] for s in common_student_indices] + uncommon_student_indices = sorted(set(range(student_vocab_size)) - set(common_student_indices)) + uncommon_teacher_indices = sorted(set(range(teacher_vocab_size)) - set(common_teacher_indices)) + + # Build chunk masks from alignment pairs (matching tokenalign.py exactly) + max_n_chunks = min(student_seq_len, teacher_seq_len) + + student_chunk_mask = torch.zeros( + (batch_size, student_seq_len, max_n_chunks), dtype=torch.bool, device=device, + ) + teacher_chunk_mask = torch.zeros( + (batch_size, teacher_seq_len, max_n_chunks), dtype=torch.bool, device=device, + ) + + for batch_idx in range(batch_size): + for chunk_idx, alignment_pair in enumerate(aligned_pairs[batch_idx][:max_n_chunks]): + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if start1 != -1 and start2 != -1: + student_chunk_mask[batch_idx, start1:end1, chunk_idx] = True + teacher_chunk_mask[batch_idx, start2:end2, chunk_idx] = True + + # log_softmax on full original logits BEFORE chunk averaging + student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1) + teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1) + + # Chunk-average log-probs over full vocabularies + student_chunk_lp = torch.bmm( + student_chunk_mask.transpose(1, 2).to(student_log_probs.dtype), student_log_probs, + ) + teacher_chunk_lp = torch.bmm( + teacher_chunk_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del student_log_probs, teacher_log_probs + + student_chunk_sizes = student_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + teacher_chunk_sizes = teacher_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) + + student_chunk_lp = student_chunk_lp / (student_chunk_sizes + 1e-10) + teacher_chunk_lp = teacher_chunk_lp / (teacher_chunk_sizes + 1e-10) + + 
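The chunk-mask/`bmm` pattern used above reduces per-token tensors to per-chunk means in a single batched matmul: a boolean mask of shape `(B, S, C)` selects each chunk's token span, and multiplying its transpose against the per-token values sums each span. A toy sketch with made-up spans:

```python
import torch

spans = [[(0, 2), (2, 3)]]       # one sample: chunk 0 = tokens 0..1, chunk 1 = token 2
B, S, C, V = 1, 4, 2, 5
torch.manual_seed(0)
values = torch.randn(B, S, V)    # stand-in for per-token log-probs

mask = torch.zeros(B, S, C, dtype=torch.bool)
for b, sample_spans in enumerate(spans):
    for c, (start, end) in enumerate(sample_spans):
        mask[b, start:end, c] = True

# one bmm turns (B, S, V) token values into (B, C, V) chunk sums
sums = torch.bmm(mask.transpose(1, 2).to(values.dtype), values)
sizes = mask.sum(dim=1, keepdim=True).to(values.dtype).transpose(1, 2)  # (B, C, 1)
chunk_means = sums / (sizes + 1e-10)  # same epsilon guard as the loss code
```

Each row of `chunk_means` is the mean of the token rows inside that chunk's span, matching the chunk-averaged log-probs computed above.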
chunk_valid = (student_chunk_sizes.squeeze(-1) > 0) & (teacher_chunk_sizes.squeeze(-1) > 0) + + if not chunk_valid.any(): + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + # --- Part 1: KL on common (exactly-mapped) vocab --- + loss_kl_common = torch.tensor(0.0, device=device, requires_grad=True) + if len(common_student_indices) > 0: + cs = torch.tensor(common_student_indices, device=device) + ct = torch.tensor(common_teacher_indices, device=device) + s_common = student_chunk_lp[:, :, cs] + t_common = teacher_chunk_lp[:, :, ct] + + if not reverse_kl: + kl_elem = torch.nn.functional.kl_div( + s_common, t_common, reduction="none", log_target=True, + ) + else: + kl_elem = torch.nn.functional.kl_div( + t_common, s_common, reduction="none", log_target=True, + ) + kl_per_chunk = kl_elem.sum(dim=-1) * chunk_valid + if chunk_valid.sum() > 0: + loss_kl_common = kl_per_chunk.sum() / chunk_valid.sum() + + # --- Part 2: L1 on uncommon (unaligned) vocab --- + loss_l1_uncommon = torch.tensor(0.0, device=device, requires_grad=True) + if len(uncommon_student_indices) > 0 or len(uncommon_teacher_indices) > 0: + s_uncommon = student_chunk_lp[:, :, torch.tensor(uncommon_student_indices, device=device)] if uncommon_student_indices else torch.empty(batch_size, max_n_chunks, 0, device=device) + t_uncommon = teacher_chunk_lp[:, :, torch.tensor(uncommon_teacher_indices, device=device)] if uncommon_teacher_indices else torch.empty(batch_size, max_n_chunks, 0, device=device) + + s_valid = s_uncommon[chunk_valid] + t_valid = t_uncommon[chunk_valid] + + if s_valid.shape[0] > 0: + with torch.no_grad(): + max_uncommon_vocab = min(s_valid.shape[-1], t_valid.shape[-1], 8192) + + if max_uncommon_vocab > 0: + s_probs = torch.exp(s_valid) + t_probs = torch.exp(t_valid) + + if s_probs.shape[-1] > max_uncommon_vocab: + s_sorted, _ = torch.topk(s_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + s_sorted = torch.sort(s_probs, dim=-1, descending=True)[0] + + if 
t_probs.shape[-1] > max_uncommon_vocab: + t_sorted, _ = torch.topk(t_probs, k=max_uncommon_vocab, dim=-1, largest=True) + else: + t_sorted = torch.sort(t_probs, dim=-1, descending=True)[0] + + del s_probs, t_probs + min_len = min(s_sorted.shape[-1], t_sorted.shape[-1]) + if min_len > 0: + loss_l1_per_chunk = torch.nn.functional.l1_loss( + s_sorted[:, :min_len], t_sorted[:, :min_len], reduction='none', + ).sum(dim=-1) + loss_l1_uncommon = loss_l1_per_chunk.mean() + del loss_l1_per_chunk + del s_sorted, t_sorted + + loss_total = (loss_kl_common + loss_l1_uncommon) * (temperature ** 2) + + # Top-1 accuracy on common vocab + top1_accuracy = 0.0 + with torch.no_grad(): + if len(common_student_indices) > 0 and chunk_valid.any(): + cs = torch.tensor(common_student_indices, device=device) + ct = torch.tensor(common_teacher_indices, device=device) + s_valid_lp = student_chunk_lp[chunk_valid][:, cs] + t_valid_lp = teacher_chunk_lp[chunk_valid][:, ct] + matches = (s_valid_lp.argmax(dim=-1) == t_valid_lp.argmax(dim=-1)).sum().item() + top1_accuracy = matches / chunk_valid.sum().item() + + del student_chunk_lp, teacher_chunk_lp, student_chunk_mask, teacher_chunk_mask + return loss_total, top1_accuracy + + def __call__( + self, + next_token_logits: torch.Tensor, + data: CrossTokenizerDistillationLossDataDict, + global_valid_seqs: torch.Tensor, + global_valid_toks: torch.Tensor, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + teacher_logits: Optional[torch.Tensor] = None, + mb_idx: Optional[int] = None, + mbs: Optional[int] = None, + teacher_topk_indices_ipc: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute cross-tokenizer distillation loss via chunk-averaged KL. 
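The sorted-L1 comparison on the uncommon vocabulary can be illustrated standalone: two probability vectors over disjoint, differently sized vocabularies are compared order statistic to order statistic after sorting descending and truncating to the shorter length (the helper name is mine):

```python
import torch

def sorted_l1(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """L1 between descending-sorted probability vectors, truncated to the shorter."""
    p_sorted = torch.sort(p, dim=-1, descending=True)[0]
    q_sorted = torch.sort(q, dim=-1, descending=True)[0]
    m = min(p_sorted.shape[-1], q_sorted.shape[-1])
    return (p_sorted[..., :m] - q_sorted[..., :m]).abs().sum(-1)

torch.manual_seed(0)
p = torch.softmax(torch.randn(3, 10), dim=-1)   # "student" uncommon-vocab probs
q = torch.softmax(torch.randn(3, 7), dim=-1)    # "teacher" uncommon-vocab probs
d = sorted_l1(p, q)
```

Sorting makes the distance invariant to how token ids are arranged within each vocabulary, which is the point: the uncommon tokens have no meaningful cross-vocabulary correspondence, so only the shape of the probability mass is compared.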
+ + For each alignment chunk (1:1, 1:many, many:1, or many:many), the + projected student and teacher distributions are averaged over their + respective spans, renormalized, and compared via KL divergence. + The per-chunk KL is then distributed back to student positions + and normalized with the standard NeMo RL masked_mean. + """ + input_ids_student = data["input_ids"] + batch_size = input_ids_student.shape[0] + + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + student_logits = next_token_logits.full_tensor().to(torch.float32) + else: + student_logits = next_token_logits.to(torch.float32) + + if teacher_logits is None: + raise ValueError( + "CrossTokenizerDistillationLossFn requires teacher_logits via IPC. " + "Set use_ipc=True in the distillation config." + ) + if self._aligned_pairs is None or self._teacher_input_ids is None: + raise ValueError( + "Cross-tokenizer data not set. " + "Call loss_fn.set_cross_tokenizer_data() before training." + ) + + if isinstance(teacher_logits, torch.distributed.tensor.DTensor): + teacher_logits_f32 = teacher_logits.full_tensor().to(torch.float32) + else: + teacher_logits_f32 = teacher_logits.to(torch.float32) + + if teacher_logits_f32.shape[-1] == 0: + raise ValueError( + f"Teacher logits have vocab dimension 0 (shape={teacher_logits_f32.shape}). " + "This typically means topk_logits=0 was passed instead of None " + "for the teacher forward pass. Cross-tokenizer distillation " + "requires full teacher logits (topk_logits=None)." 
+ ) + + aligned_pairs = self._aligned_pairs + if mb_idx is not None and mbs is not None: + mb_start = mb_idx * mbs + mb_end = mb_start + batch_size + aligned_pairs = aligned_pairs[mb_start:mb_end] + + self.token_aligner = self.token_aligner.to(student_logits.device) + device = student_logits.device + + temperature = self.cfg.get("temperature", 1.0) + vocab_topk = self.cfg.get("vocab_topk", 8192) + reverse_kl = self.cfg.get("reverse_kl", False) + exact_match_only = self.cfg.get("exact_token_match_only", False) + use_gold_loss = self.cfg.get("gold_loss", False) + use_xtoken_loss = self.cfg.get("xtoken_loss", False) + student_seq_len = student_logits.shape[1] + teacher_seq_len = teacher_logits_f32.shape[1] + student_vocab_size = student_logits.shape[-1] + teacher_vocab_size = teacher_logits_f32.shape[-1] + + # -- 1. Filter alignment pairs and count chunks -- + filtered_pairs: list[list[tuple]] = [] + total_chunks = 0 + for batch_idx in range(batch_size): + batch_pairs = [] + for pair in aligned_pairs[batch_idx]: + s1text, s2text, s1_start, s1_end, s2_start, s2_end = pair[:6] + if exact_match_only and (s1_end - s1_start != 1 or s2_end - s2_start != 1): + continue + if s1_start == -1 or s2_start == -1: + continue + if s1_end > student_seq_len or s2_end > teacher_seq_len: + continue + batch_pairs.append(pair) + filtered_pairs.append(batch_pairs) + total_chunks = max(total_chunks, len(batch_pairs)) + + if total_chunks == 0: + loss = torch.tensor(0.0, device=device, requires_grad=True) + return loss, {"loss": 0.0, "topk_accuracy": 0.0, "num_chunks": 0} + + # -- 2. 
Build chunk masks (B, seq_len, num_chunks) -- + proj_mask = torch.zeros( + batch_size, student_seq_len, total_chunks, dtype=torch.bool, device=device, + ) + tgt_mask = torch.zeros( + batch_size, teacher_seq_len, total_chunks, dtype=torch.bool, device=device, + ) + for batch_idx in range(batch_size): + for chunk_idx, pair in enumerate(filtered_pairs[batch_idx]): + _, _, s1_start, s1_end, s2_start, s2_end = pair[:6] + proj_mask[batch_idx, s1_start:s1_end, chunk_idx] = True + tgt_mask[batch_idx, s2_start:s2_end, chunk_idx] = True + + # ================================================================ + # Gold loss path: common-vocab KL + uncommon-vocab sorted L1. + # Bypasses the projection matrix for tokens with exact 1:1 mappings. + # Matches tokenalign.py compute_KL_loss_optimized gold_loss branch. + # ================================================================ + if use_gold_loss: + loss, top1_accuracy = self._compute_gold_loss( + student_logits, teacher_logits_f32, aligned_pairs, + batch_size, student_seq_len, teacher_seq_len, + student_vocab_size, teacher_vocab_size, + temperature, reverse_kl, use_xtoken_loss, device, + ) + else: + # ================================================================ + # Standard projection-based path + # ================================================================ + + # -- 3. Global vocabulary filtering (top-k teacher tokens) -- + with torch.no_grad(): + if vocab_topk == 0 or vocab_topk >= teacher_vocab_size: + global_top_indices = torch.arange(teacher_vocab_size, device=device) + else: + teacher_flat = teacher_logits_f32.view(-1, teacher_vocab_size) + importance = teacher_flat.max(dim=0)[0] + _, global_top_indices = torch.topk( + importance, k=min(vocab_topk, teacher_vocab_size), dim=-1, + ) + global_top_indices = global_top_indices.sort()[0] + + # -- 4. 
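The global vocabulary filtering in step 3 keeps the `k` teacher token ids that achieve the highest logit anywhere in the batch, then re-sorts the surviving ids so they index the original vocabulary in ascending order. A standalone sketch of that reduction (function name is illustrative):

```python
import torch

def global_topk_vocab(teacher_logits: torch.Tensor, k: int) -> torch.Tensor:
    """Ids of the k tokens with the highest logit anywhere in the batch, ascending."""
    V = teacher_logits.shape[-1]
    if k == 0 or k >= V:
        return torch.arange(V)          # 0 means "keep the full vocabulary"
    importance = teacher_logits.reshape(-1, V).max(dim=0)[0]  # best logit per token id
    _, top_idx = torch.topk(importance, k=k)
    return top_idx.sort()[0]

torch.manual_seed(0)
logits = torch.randn(2, 6, 32)          # [B, S, V_teacher]
keep = global_topk_vocab(logits, k=8)
```

Sorting the kept ids means the reduced axis preserves the teacher vocabulary's ordering, so a single `index_select` on the projection matrix stays consistent with the reduced teacher logits.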
Project student probs to teacher vocab -- + projected_student = self._project_student_to_teacher( + student_logits, teacher_vocab_size, temperature, global_top_indices, device, + ) + + # -- 5. Teacher log-probs in reduced vocab -- + teacher_logits_reduced = teacher_logits_f32[:, :, global_top_indices] + teacher_log_probs = torch.log_softmax(teacher_logits_reduced / temperature, dim=-1) + del teacher_logits_reduced + + # -- 6. Chunk-averaged distributions -- + proj_chunks = torch.bmm( + proj_mask.transpose(1, 2).to(projected_student.dtype), projected_student, + ) + tgt_log_chunks = torch.bmm( + tgt_mask.transpose(1, 2).to(teacher_log_probs.dtype), teacher_log_probs, + ) + del projected_student, teacher_log_probs + + proj_sizes = proj_mask.sum(dim=1).unsqueeze(-1).to(proj_chunks.dtype) + tgt_sizes = tgt_mask.sum(dim=1).unsqueeze(-1).to(tgt_log_chunks.dtype) + + proj_chunks = proj_chunks / (proj_sizes + 1e-10) + tgt_log_chunks = tgt_log_chunks / (tgt_sizes + 1e-10) + + proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + 1e-10) + proj_log_chunks = torch.log(proj_chunks + 1e-10) + + chunk_valid = (proj_sizes.squeeze(-1) > 0) & (tgt_sizes.squeeze(-1) > 0) + + # -- 7. KL divergence per chunk -- + if reverse_kl: + kl_per_elem = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True, + ) + else: + kl_per_elem = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True, + ) + kl_per_chunk = kl_per_elem.sum(dim=-1) * (temperature ** 2) + kl_per_chunk = kl_per_chunk * chunk_valid + del proj_chunks, tgt_log_chunks, proj_log_chunks, kl_per_elem + + # -- 8. 
Scalar loss -- + num_valid_chunks = chunk_valid.sum() + if num_valid_chunks > 0: + loss = kl_per_chunk.sum() / num_valid_chunks + else: + loss = torch.tensor(0.0, device=device, requires_grad=True) + top1_accuracy = 0.0 + + # ================================================================ + # Optional CE (next-token prediction) loss, matching the DDP + # train_distillation_ddp.py logic: + # without dynamic scaling: loss = kl * kl_weight + ce * ce_scale + # with dynamic scaling: loss = kl * (ce/kl) + ce + # ================================================================ + kl_loss = loss + ce_loss_scale = self.cfg.get("ce_loss_scale", 0.0) + dynamic_loss_scaling = self.cfg.get("dynamic_loss_scaling", False) + ce_loss_value = 0.0 + + if ce_loss_scale > 0.0 or dynamic_loss_scaling: + ce_loss = torch.nn.functional.cross_entropy( + student_logits[:, :-1].reshape(-1, student_logits.shape[-1]), + input_ids_student[:, 1:].reshape(-1), + ignore_index=-100, + ) + ce_loss_value = float(ce_loss.item()) + + if dynamic_loss_scaling and kl_loss.item() > 0: + dls_scale = ce_loss.item() / kl_loss.item() + loss = kl_loss * dls_scale + ce_loss + else: + loss = kl_loss + ce_loss * ce_loss_scale + + # One-time debug dump for sanity check comparison with standalone TokenAligner + if not getattr(self, '_debug_dumped', False): + self._debug_dumped = True + raw_loss = float(loss.item()) if loss.ndim == 0 else float(loss) + print(f"[CrossTokenKL DEBUG] raw_chunk_loss={raw_loss:.6f}, " + f"gold_loss={use_gold_loss}, " + f"student_shape={student_logits.shape}, " + f"teacher_shape={teacher_logits_f32.shape}, " + f"total_filtered_pairs={sum(len(fp) for fp in filtered_pairs)}", flush=True) + try: + import os + dump_dir = os.environ.get("CROSS_TOK_DEBUG_DIR", "/tmp/cross_tok_debug") + os.makedirs(dump_dir, exist_ok=True) + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + teacher_ids = self._teacher_input_ids + if mb_idx is not None and mbs is not None: + 
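The two loss-combination modes above reduce to simple scalar arithmetic (here with an implicit KL weight of 1, matching the code; the numbers are illustrative). With dynamic scaling the detached ratio `ce/kl` rescales the KL term to the magnitude of the CE term:

```python
def combine_losses(kl: float, ce: float, ce_loss_scale: float, dynamic: bool) -> float:
    """Sketch of the KL/CE combination logic; kl and ce are scalar loss values."""
    if dynamic and kl > 0:
        # the ratio ce/kl is treated as a constant scale (computed via .item()
        # in the code above), so the KL term is brought to CE's magnitude
        return kl * (ce / kl) + ce
    return kl + ce * ce_loss_scale

static_total = combine_losses(kl=4.0, ce=1.5, ce_loss_scale=0.5, dynamic=False)
dynamic_total = combine_losses(kl=4.0, ce=1.5, ce_loss_scale=0.5, dynamic=True)
```

Note that with dynamic scaling the combined value is always `2 * ce`; what changes during training is the gradient, which still flows through the KL term at its rescaled magnitude.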
teacher_ids = teacher_ids[mb_idx * mbs : mb_idx * mbs + batch_size] + torch.save({ + "student_logits": student_logits.cpu(), + "teacher_logits": teacher_logits_f32.cpu(), + "input_ids_student": input_ids_student.cpu(), + "input_ids_teacher": teacher_ids.cpu(), + "aligned_pairs": aligned_pairs, + "config": dict(self.cfg), + }, os.path.join(dump_dir, f"debug_rank{rank}.pt")) + print(f"[CrossTokenKL DEBUG] Saved debug tensors to {dump_dir}/debug_rank{rank}.pt", flush=True) + except Exception as e: + print(f"[CrossTokenKL DEBUG] Failed to save debug tensors: {e}", flush=True) + + # Scale for NeMo RL distributed training + token_mask = data["token_mask"] + sample_mask = data["sample_mask"] + max_len = min(token_mask.shape[1] - 1, student_seq_len) + local_mask = token_mask[:, 1 : max_len + 1] * sample_mask.unsqueeze(-1) + local_valid_toks = local_mask.sum() + + if local_valid_toks > 0 and global_valid_toks > 0: + loss = loss * local_valid_toks / global_valid_toks + else: + loss = loss * 0.0 + + num_valid = sum(len(fp) for fp in filtered_pairs) + metrics = { + "loss": float(loss.item()) if loss.ndim == 0 else loss, + "kl_loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, + "ce_loss": ce_loss_value, + "topk_accuracy": top1_accuracy, + "num_valid_samples": int(batch_size), + "num_chunks": num_valid, + "alignment_density": num_valid / max(1, batch_size * student_seq_len), + } + + return loss, metrics diff --git a/nemo_rl/algorithms/off_policy_distillation.py b/nemo_rl/algorithms/off_policy_distillation.py new file mode 100644 index 0000000000..e642ea046b --- /dev/null +++ b/nemo_rl/algorithms/off_policy_distillation.py @@ -0,0 +1,1578 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Off-Policy Distillation Algorithm
+
+This module implements off-policy distillation where:
+- A fixed dataset of prompt-response pairs is used (no student generation)
+- Teacher provides logits for the fixed responses
+- Student aligns with teacher using KL divergence loss
+
+Key differences from on-policy distillation (in distillation.py):
+- No student generation step - uses pre-existing responses from the dataset
+- No environment needed for reward computation
+- Simpler training loop without rollout generation
+"""
+
+import math
+import multiprocessing
+import os
+import warnings
+import importlib.util
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from pathlib import Path
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, Callable, NotRequired, Optional, TypedDict, TypeVar, cast
+else:
+    from typing import Any, Callable, Optional, TypedDict, TypeVar, cast
+    from typing_extensions import NotRequired
+
+import numpy as np
+import torch
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import AutoConfig, AutoTokenizer
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from nemo_rl.algorithms.loss_functions import (
+    CrossTokenizerDistillationLossFn,
+    DistillationLossConfig,
+    DistillationLossDataDict,
+    DistillationLossFn,
+)
+from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
+from nemo_rl.data import DataConfig
+from nemo_rl.data.collate_fn import rl_collate_fn
+from nemo_rl.data.datasets import
AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ( + ClusterConfig, + RayVirtualCluster, +) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface +from nemo_rl.models.policy.lm_policy import Policy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import Logger, LoggerConfig +from nemo_rl.utils.nsys import maybe_gpu_profile_step +from nemo_rl.utils.timer import TimeoutChecker, Timer + +# =============================================================================== +# Configuration +# =============================================================================== +TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) + + +class TokenAlignerConfig(TypedDict, total=False): + """Configuration for cross-tokenizer distillation via TokenAligner. + + When enabled, teacher and student may use different tokenizers/vocabularies. + A precomputed projection matrix maps between the two vocabulary spaces. 
+ """ + enabled: bool # Master switch for cross-tokenizer mode + projection_matrix_path: str # Path to .pt projection matrix file + use_sparse_format: bool # True = sparse COO format, False = dense indices/values + loss_type: str # 'KL', 'cross_entropy', or 'chunked_ce' + exact_token_match_only: bool # Only use 1:1 aligned token positions for loss + temperature: float # Softmax temperature for KL computation + vocab_topk: int # Reduce teacher vocab to top-k for speed (0 = all) + reverse_kl: bool # If True, use reverse KL direction + projection_matrix_multiplier: float # Scaling factor for projection matrix + max_comb_len: int # Max combination length for token alignment DP + learnable: bool # If True, projection matrix is trainable + project_teacher_to_student: bool # If True, project teacher->student instead of student->teacher + use_char_offset: bool # If True, try char-offset alignment before DP fallback + force_dp_only: bool # If True, disable char-offset path and run DP for all samples + use_cuda_dp: bool # If True, patch TokenAligner chunked DP base case with CUDA kernel + dp_chunk_size: int # Chunk size used by DP chunked solver + use_align_fast: bool # If True, use align_fast for DP path; default False for parity + + +class OffPolicyDistillationConfig(TypedDict): + """Configuration for off-policy distillation training. 
+ + Simplified compared to on-policy: + - No num_generations_per_prompt (we use fixed responses) + - No max_rollout_turns (no generation) + """ + num_prompts_per_step: int # Batch size + max_num_steps: int # Maximum number of steps to train for + max_num_epochs: int # Maximum number of epochs to train for + topk_logits_k: int # Top-k logits for sparse KL loss + seed: int + # Validation settings + val_period: NotRequired[int] # Run validation every N steps (0 = disabled) + val_batches: NotRequired[int] # Number of validation batches (0 = all) + val_global_batch_size: NotRequired[int] # Validation batch size + val_micro_batch_size: NotRequired[int] # Validation micro batch size + val_at_start: NotRequired[bool] # Run validation before training starts + # CPU processes for parallel cross-tokenizer decode/encode/align (None = auto, 1 = sequential) + cross_tokenizer_num_workers: NotRequired[Optional[int]] + + +class OffPolicyDistillationSaveState(TypedDict): + """State to save for checkpointing.""" + total_steps: int # Track total number of steps across all epochs + current_epoch: int # Track current epoch + current_step: int # Track step within current epoch + consumed_samples: int + total_valid_tokens: int # Track total number of non-padding tokens during training + + +def _default_distillation_save_state() -> OffPolicyDistillationSaveState: + return { + "current_epoch": 0, + "current_step": 0, + "total_steps": 0, + "consumed_samples": 0, + "total_valid_tokens": 0, + } + + +class OffPolicyMasterConfig(TypedDict): + """Main configuration structure for off-policy distillation. 
+ + Key difference from on-policy MasterConfig: + - No 'env' config (no environment needed) + """ + policy: PolicyConfig # Student model configuration + teacher: PolicyConfig # Teacher model configuration + loss_fn: DistillationLossConfig # Loss function configuration + data: DataConfig # Data configuration + distillation: OffPolicyDistillationConfig # Distillation configuration + logger: LoggerConfig # Logger configuration + cluster: ClusterConfig # Cluster configuration + checkpointing: CheckpointingConfig # Checkpointing configuration + token_aligner: NotRequired[TokenAlignerConfig] # Cross-tokenizer config (optional) + + +class _PrefetchedBatchPack(TypedDict): + batch: BatchedDataDict[DatumSpec] + flat_messages: BatchedDataDict[DatumSpec] + input_lengths: torch.Tensor + train_data: BatchedDataDict[DistillationLossDataDict] + ct_future: Any + + +# =============================================================================== +# Cross-Tokenizer Parallel Processing +# =============================================================================== + +# Module-level global set by _init_align_worker for each pool process. +_ct_token_aligner = None +_ct_dp_chunk_size = 128 +_ct_use_align_fast = False + + +def _init_align_worker(token_aligner, dp_chunk_size: int, use_align_fast: bool): + """Initializer for ProcessPoolExecutor workers. + + Stores the TokenAligner once per process to avoid re-pickling every call. 
+
+    If the aligner was configured with the CUDA DP base case, the monkeypatch
+    is re-applied here: worker processes import the kernel module themselves
+    rather than inheriting the patched state from the parent process.
+ """ + global _ct_token_aligner, _ct_dp_chunk_size, _ct_use_align_fast + _ct_token_aligner = token_aligner + _ct_dp_chunk_size = int(dp_chunk_size) + _ct_use_align_fast = bool(use_align_fast) + if getattr(_ct_token_aligner, "_use_cuda_dp", False): + cuda_dp_path_str = getattr(_ct_token_aligner, "_cuda_dp_module_path", "") + if cuda_dp_path_str: + spec = importlib.util.spec_from_file_location("x_token_cuda_dp_worker", cuda_dp_path_str) + if spec is not None and spec.loader is not None: + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + mod.monkeypatch_tokenaligner_cuda_basecase() + + +def _align_chunk(args): + """Align a chunk of (student, teacher) token-ID pairs. + + Called by ProcessPoolExecutor. Uses align_fast only when enabled; + otherwise falls back to regular align. + + Args: + args: (student_ids_chunk, teacher_ids_chunk) — both list[list[int]]. + + Returns: + aligned_pairs from ``TokenAligner.align_fast`` or ``TokenAligner.align``. + """ + student_ids_chunk, teacher_ids_chunk = args + student_t = torch.tensor(student_ids_chunk) + teacher_t = torch.tensor(teacher_ids_chunk) + if _ct_use_align_fast and _ct_token_aligner._student_canon_map is not None: + return _ct_token_aligner.align_fast( + student_t, teacher_t, chunk_size=_ct_dp_chunk_size + ) + return _ct_token_aligner.align( + student_t, teacher_t, chunk_size=_ct_dp_chunk_size + ) + + +def _align_by_char_offsets( + s_content: list[tuple[int, int, int]], + t_content: list[tuple[int, int, int]], +) -> list[tuple]: + """Align tokens via character offsets in O(n+m). + + Both sequences are tokenizations of the same text, so their character + spans partition the same string. A two-pointer walk groups tokens + whose character boundaries converge. + + Args: + s_content: [(char_start, char_end, token_position), ...] for student, + sorted by char_start, excluding special/pad tokens. + t_content: same for teacher. 
+
+    Returns:
+        aligned_pairs in the standard 7-tuple format:
+        (s1_tokens, s2_tokens, s_pos_start, s_pos_end, t_pos_start, t_pos_end, is_correct)
+    """
+    pairs: list[tuple] = []
+    si, ti = 0, 0
+    n_s, n_t = len(s_content), len(t_content)
+
+    while si < n_s and ti < n_t:
+        s_group_start = si
+        t_group_start = ti
+        s_char_end = s_content[si][1]
+        t_char_end = t_content[ti][1]
+
+        while s_char_end != t_char_end:
+            if s_char_end < t_char_end:
+                si += 1
+                if si >= n_s:
+                    break
+                s_char_end = s_content[si][1]
+            else:
+                ti += 1
+                if ti >= n_t:
+                    break
+                t_char_end = t_content[ti][1]
+
+        if s_char_end != t_char_end:
+            # One side was exhausted mid-group; rewind to the group start so
+            # the partially consumed tokens are emitted as unmatched pairs
+            # below instead of being silently dropped.
+            si = s_group_start
+            ti = t_group_start
+            break
+
+        pairs.append((
+            [], [],
+            s_content[s_group_start][2], s_content[si][2] + 1,
+            t_content[t_group_start][2], t_content[ti][2] + 1,
+            True,
+        ))
+        si += 1
+        ti += 1
+
+    for i in range(si, n_s):
+        pos = s_content[i][2]
+        pairs.append(([], [], pos, pos + 1, -1, -1, False))
+    for i in range(ti, n_t):
+        pos = t_content[i][2]
+        pairs.append(([], [], -1, -1, pos, pos + 1, False))
+
+    return pairs
+
+
+def _process_cross_tokenizer_batch(
+    train_input_ids: torch.Tensor,
+    batch_loss_multiplier: torch.Tensor,
+    extra_env: Any,
+    tokenizer: PreTrainedTokenizerBase,
+    teacher_tokenizer: PreTrainedTokenizerBase,
+    token_aligner: Any,
+    use_char_offset: bool,
+    use_align_fast: bool,
+    dp_chunk_size: int,
+    ct_pool: Optional[ProcessPoolExecutor],
+    max_teacher_len_rt: int,
+) -> tuple[torch.Tensor, list[Any], BatchedDataDict]:
+    """Prepare teacher inputs + aligned pairs for one training batch."""
+    import time as _time  # local alias; keeps the module namespace clean
+
+    student_ids = train_input_ids
+    batch_size_ct = student_ids.shape[0]
+
+    _t0 = _time.time()
+    has_raw_text = bool(
+        extra_env
+        and len(extra_env) == batch_size_ct
+        and all(
+            info is not None
+            and isinstance(info, dict)
+            and "raw_text" in info
+            for info in extra_env
+        )
+    )
+    if has_raw_text:
+        texts = [info["raw_text"] for info in extra_env]
+    else:
+        # Fallback only when raw text is unavailable for the batch.
+ texts = tokenizer.batch_decode(student_ids.tolist(), skip_special_tokens=True) + _t1 = _time.time() + + teacher_encoded = teacher_tokenizer( + texts, + max_length=max_teacher_len_rt, + padding="max_length", + truncation=True, + return_tensors="pt", + return_offsets_mapping=True, + ) + teacher_input_ids = teacher_encoded["input_ids"] + teacher_attention_mask = teacher_encoded["attention_mask"] + teacher_offsets = teacher_encoded["offset_mapping"] + + student_re = tokenizer( + texts, + max_length=student_ids.shape[1], + padding="max_length", + truncation=True, + return_tensors="pt", + return_offsets_mapping=True, + ) + # Align against student/teacher IDs tokenized from the same raw text. + # This keeps alignment semantics symmetric across tokenizers. + student_align_ids = student_re["input_ids"] + student_offsets = student_re["offset_mapping"] + _t2 = _time.time() + + # --- Vectorized pre-check: which samples can try char-offset? --- + s_off_np = student_offsets.numpy() + t_off_np = teacher_offsets.numpy() + + s_nonzero = (s_off_np[:, :, 0] != 0) | (s_off_np[:, :, 1] != 0) + t_nonzero = (t_off_np[:, :, 0] != 0) | (t_off_np[:, :, 1] != 0) + + s_has = s_nonzero.any(axis=1) + t_has = t_nonzero.any(axis=1) + + s_last = s_nonzero.shape[1] - 1 - np.flip(s_nonzero, axis=1).argmax(axis=1) + t_last = t_nonzero.shape[1] - 1 - np.flip(t_nonzero, axis=1).argmax(axis=1) + s_last_end = s_off_np[np.arange(batch_size_ct), s_last, 1] + t_last_end = t_off_np[np.arange(batch_size_ct), t_last, 1] + + if not use_char_offset: + can_try_offset = np.zeros(batch_size_ct, dtype=bool) + else: + can_try_offset = s_has & t_has & (s_last_end == t_last_end) + + # --- Vectorized offset filtering (avoid per-sample .tolist()) --- + # Pre-extract content indices per sample using numpy + s_content_per_sample = [] + t_content_per_sample = [] + for idx in range(batch_size_ct): + if can_try_offset[idx]: + s_mask = s_nonzero[idx] + t_mask = t_nonzero[idx] + s_positions = np.where(s_mask)[0] + t_positions 
= np.where(t_mask)[0] + s_content_per_sample.append([ + (int(s_off_np[idx, p, 0]), int(s_off_np[idx, p, 1]), int(p)) + for p in s_positions + ]) + t_content_per_sample.append([ + (int(t_off_np[idx, p, 0]), int(t_off_np[idx, p, 1]), int(p)) + for p in t_positions + ]) + else: + s_content_per_sample.append(None) + t_content_per_sample.append(None) + + # --- Char-offset alignment (sequential, fast O(n+m) per sample) --- + aligned_pairs: list[Any] = [None] * batch_size_ct + dp_samples_s = [] + dp_samples_t = [] + dp_slot_indices = [] + + for idx in range(batch_size_ct): + if not can_try_offset[idx]: + dp_samples_s.append(student_align_ids[idx : idx + 1].tolist()) + dp_samples_t.append(teacher_input_ids[idx : idx + 1].tolist()) + dp_slot_indices.append(idx) + continue + + pairs = _align_by_char_offsets( + s_content_per_sample[idx], t_content_per_sample[idx] + ) + n_correct = sum(1 for p in pairs if p[6]) + if n_correct == 0 or n_correct / len(pairs) < 0.5: + dp_samples_s.append(student_align_ids[idx : idx + 1].tolist()) + dp_samples_t.append(teacher_input_ids[idx : idx + 1].tolist()) + dp_slot_indices.append(idx) + else: + aligned_pairs[idx] = pairs + + dp_fallback = len(dp_slot_indices) + n_offsets = batch_size_ct - dp_fallback + + # --- DP alignment for fallbacks (parallelized) --- + if dp_fallback > 0: + if ct_pool is not None and dp_fallback > 1: + chunks = list(zip(dp_samples_s, dp_samples_t)) + dp_results = list(ct_pool.map(_align_chunk, chunks)) + for i, slot in enumerate(dp_slot_indices): + aligned_pairs[slot] = dp_results[i][0] + else: + for i, slot in enumerate(dp_slot_indices): + s_t = torch.tensor(dp_samples_s[i]) + t_t = torch.tensor(dp_samples_t[i]) + if use_align_fast and token_aligner._student_canon_map is not None: + dp_result = token_aligner.align_fast( + s_t, t_t, chunk_size=dp_chunk_size + ) + else: + dp_result = token_aligner.align( + s_t, t_t, chunk_size=dp_chunk_size + ) + aligned_pairs[slot] = dp_result[0] + + _t3 = _time.time() + print( + f" [CT 
timing] decode={_t1-_t0:.2f}s, " + f"encode={_t2-_t1:.2f}s, " + f"align={_t3-_t2:.2f}s " + f"(offsets: {n_offsets}, " + f"dp_fallback: {dp_fallback})", + flush=True, + ) + + teacher_input_lengths_ct = teacher_attention_mask.sum(dim=1) + + teacher_token_mask = torch.zeros_like(teacher_input_ids, dtype=torch.float32) + for i in range(batch_size_ct): + teacher_token_mask[i, : teacher_input_lengths_ct[i]] = 1.0 + + teacher_data = BatchedDataDict( + { + "input_ids": teacher_input_ids, + "input_lengths": teacher_input_lengths_ct, + "token_mask": teacher_token_mask, + "sample_mask": batch_loss_multiplier, + } + ) + teacher_data.to("cpu") + + return teacher_input_ids, aligned_pairs, teacher_data + + +# =============================================================================== +# Setup & Initialization +# =============================================================================== +def check_vocab_equality( + tokenizer: TokenizerType, student_model_name: str, teacher_model_name: str +) -> None: + """Check if the vocab of the tokenizer (student) and the teacher tokenizer are equal.""" + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name) + + skip_hint = "Set NRL_SKIP_DISTILLATION_TOKENIZER_CHECK=true to skip this check." + + # 1) Exact token->id mapping equality + vocab_a = tokenizer.get_vocab() + vocab_b = teacher_tokenizer.get_vocab() + assert vocab_a == vocab_b, ( + f"Token->ID mapping differs between student and teacher. {skip_hint}" + ) + + # 2) Size consistency (sanity checks) + assert len(tokenizer) == len(teacher_tokenizer), ( + f"Effective vocab sizes differ between student and teacher. 
{skip_hint}" + ) + + # 3) Check model.config.vocab_size to guarantee the last dimension of the logits is the same + student_config = AutoConfig.from_pretrained(student_model_name) + teacher_config = AutoConfig.from_pretrained(teacher_model_name) + assert student_config.vocab_size == teacher_config.vocab_size, ( + f"Model config vocab sizes differ between student and teacher. {skip_hint}" + ) + + +def setup( + master_config: OffPolicyMasterConfig, + tokenizer: TokenizerType, + train_dataset: AllTaskProcessedDataset, + val_dataset: Optional[AllTaskProcessedDataset] = None, +) -> tuple[ + ColocatablePolicyInterface, # student_policy + ColocatablePolicyInterface, # teacher_policy + StatefulDataLoader, # train_dataloader + Optional[StatefulDataLoader], # val_dataloader + DistillationLossFn, + Logger, + CheckpointManager, + OffPolicyDistillationSaveState, + OffPolicyMasterConfig, +]: + """Setup for off-policy distillation algorithm. + + Key differences from on-policy setup(): + - No student_generation interface (we don't generate responses) + - Simpler cluster setup (training only, no inference cluster needed) + + Returns: + tuple of student_policy, teacher_policy, train_dataloader, val_dataloader, + loss_fn, logger, checkpointer, distillation_save_state, master_config + """ + # Extract configuration + policy_config = master_config["policy"] + teacher_config = master_config["teacher"] + loss_config = master_config["loss_fn"] + distillation_config = master_config["distillation"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + # Disallow SP + packing for dtensor path + for cfg, who in ((policy_config, "student"), (teacher_config, "teacher")): + dtensor_enabled = cfg["dtensor_cfg"]["enabled"] + sequence_packing_enabled = ( + "sequence_packing" in cfg and cfg["sequence_packing"]["enabled"] + ) + sequence_parallel_enabled = ( + "sequence_parallel" in cfg["dtensor_cfg"] + and 
cfg["dtensor_cfg"]["sequence_parallel"] + ) + + if dtensor_enabled and sequence_packing_enabled and sequence_parallel_enabled: + raise AssertionError( + f"Distillation does not support DTensor sequence parallel + sequence packing ({who} policy). " + "Please refer to https://github.com/NVIDIA-NeMo/RL/issues/1178 for more details." + ) + + # Set random seed + set_seed(distillation_config["seed"]) + + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + distillation_save_state: Optional[OffPolicyDistillationSaveState] = cast( + Optional[OffPolicyDistillationSaveState], + checkpointer.load_training_info(last_checkpoint_path), + ) + if distillation_save_state is None: + distillation_save_state = _default_distillation_save_state() + + # ========================== + # Data + # ========================== + dataloader = StatefulDataLoader( + train_dataset, + batch_size=distillation_config["num_prompts_per_step"], + shuffle=data_config.get("shuffle", True), + collate_fn=rl_collate_fn, + drop_last=True, + ) + + if last_checkpoint_path: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + + print( + f" ✓ Training dataloader loaded with {len(train_dataset)} samples", flush=True + ) + + # Load validation dataloader if provided + val_dataloader: Optional[StatefulDataLoader] = None + val_period = distillation_config.get("val_period", 0) + val_at_start = distillation_config.get("val_at_start", False) + if val_period > 0 or val_at_start: + assert val_dataset is not None, ( + "Validation dataset is required if validation is enabled " + "(val_period > 0 or val_at_start = True)" + 
) + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=distillation_config.get( + "val_global_batch_size", distillation_config["num_prompts_per_step"] + ), + shuffle=False, + collate_fn=rl_collate_fn, + drop_last=False, + ) + print( + f" ✓ Validation dataloader loaded with {len(val_dataset)} samples", + flush=True, + ) + + # ========================== + # Cluster + # ========================== + # For off-policy distillation, we only need a training cluster + # No inference cluster needed since we don't generate responses + print("\n▶ Setting up compute cluster...", flush=True) + + cluster = RayVirtualCluster( + name="off_policy_distillation_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=3, + ) + print( + f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes", + flush=True, + ) + + # ========================== + # Cross-Tokenizer Setup + # ========================== + token_aligner_cfg = master_config.get("token_aligner", {}) + cross_tokenizer_enabled = token_aligner_cfg.get("enabled", False) + token_aligner = None + teacher_tokenizer = None + + if cross_tokenizer_enabled: + from nemo_rl.algorithms.x_token import TokenAligner + + print("\n▶ Setting up cross-tokenizer distillation (TokenAligner)...", flush=True) + teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_config["model_name"]) + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + token_aligner = TokenAligner( + teacher_tokenizer_name=teacher_config["model_name"], + student_tokenizer_name=policy_config["model_name"], + max_comb_len=token_aligner_cfg.get("max_comb_len", 4), + projection_matrix_multiplier=token_aligner_cfg.get( + "projection_matrix_multiplier", 1.0 + ), + ) + token_aligner._load_logits_projection_map( + 
file_path=token_aligner_cfg["projection_matrix_path"],
+            use_sparse_format=token_aligner_cfg.get("use_sparse_format", True),
+            learnable=token_aligner_cfg.get("learnable", False),
+            device="cpu",
+        )
+        if token_aligner_cfg.get("project_teacher_to_student", False):
+            token_aligner.create_reverse_projection_matrix(device="cpu")
+
+        print(
+            f" ✓ TokenAligner initialized "
+            f"({policy_config['model_name']} → {teacher_config['model_name']})",
+            flush=True,
+        )
+
+        token_aligner.precompute_canonical_maps()
+        if token_aligner_cfg.get("use_cuda_dp", False):
+            cuda_dp_path = (
+                Path(__file__).resolve().parents[2] / "x_token" / "cuda_tokenalign_dp.py"
+            )
+            if not cuda_dp_path.exists():
+                raise FileNotFoundError(
+                    f"Requested token_aligner.use_cuda_dp=true but file not found: {cuda_dp_path}"
+                )
+            spec = importlib.util.spec_from_file_location(
+                "x_token_cuda_dp", str(cuda_dp_path)
+            )
+            if spec is None or spec.loader is None:
+                raise ImportError(f"Failed to load CUDA DP module from: {cuda_dp_path}")
+            mod = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(mod)
+            mod.monkeypatch_tokenaligner_cuda_basecase()
+            token_aligner._use_cuda_dp = True
+            token_aligner._cuda_dp_module_path = str(cuda_dp_path)
+            print(" ✓ CUDA DP monkeypatch enabled for TokenAligner", flush=True)
+        if token_aligner_cfg.get("force_dp_only", False):
+            print(" ✓ force_dp_only enabled (char-offset disabled)", flush=True)
+
+    # ==========================
+    # Teacher Policy
+    # ==========================
+    print("\n▶ Setting up teacher policy...", flush=True)
+
+    if not cross_tokenizer_enabled:
+        skip_env = os.getenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", "")
+        # os.getenv returns a string, and bool("false") is True, so parse the
+        # value explicitly instead of relying on bool().
+        if skip_env.strip().lower() not in ("1", "true", "yes"):
+            check_vocab_equality(
+                tokenizer, policy_config["model_name"], teacher_config["model_name"]
+            )
+
+    if "megatron_cfg" in teacher_config and teacher_config["megatron_cfg"]["enabled"]:
+        ## NOTE: this is equal to the total number of scheduler steps
+        total_train_iters = min(
+            distillation_config["max_num_steps"],
+            distillation_config["max_num_epochs"] *
len(dataloader), + ) + teacher_config["megatron_cfg"]["train_iters"] = total_train_iters + + teacher_policy = Policy( + name_prefix="teacher", + cluster=cluster, + config=teacher_config, + tokenizer=teacher_tokenizer if cross_tokenizer_enabled else tokenizer, + weights_path=None, + optimizer_path=None, + init_optimizer=False, + init_reference_model=False, + ) + teacher_policy.offload_after_refit() + + # ========================== + # Student Policy + # ========================== + # Note: No student_generation interface for off-policy distillation + print("\n▶ Setting up student policy...", flush=True) + + # Checkpoint paths + weights_path = None + optimizer_path = None + if last_checkpoint_path: + weights_path = Path(last_checkpoint_path) / "policy" / "weights" + optimizer_path = Path(last_checkpoint_path) / "policy" / "optimizer" + + if "megatron_cfg" in policy_config and policy_config["megatron_cfg"]["enabled"]: + ## NOTE: this is equal to the total number of scheduler steps + total_train_iters = min( + distillation_config["max_num_steps"], + distillation_config["max_num_epochs"] * len(dataloader), + ) + policy_config["megatron_cfg"]["train_iters"] = total_train_iters + + student_policy = Policy( + name_prefix="student", + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + init_optimizer=True, + init_reference_model=False, + ) + + if cross_tokenizer_enabled: + loss_fn = CrossTokenizerDistillationLossFn(loss_config, token_aligner) + else: + loss_fn = DistillationLossFn(loss_config) + + print("\n" + "=" * 60) + print(" " * 12 + "OFF-POLICY DISTILLATION SETUP COMPLETE") + print("=" * 60 + "\n", flush=True) + + return ( + student_policy, + teacher_policy, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + distillation_save_state, + master_config, + ) + + +# =============================================================================== +# Training +# 
=============================================================================== + + +def validate( + student_policy: ColocatablePolicyInterface, + teacher_policy: ColocatablePolicyInterface, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: DistillationLossFn, + step: int, + master_config: OffPolicyMasterConfig, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Run validation on the validation dataset for off-policy distillation. + + Computes teacher top-k logits and student distillation loss on validation data + in eval mode (no gradient updates). + + Args: + student_policy: The student policy to evaluate. + teacher_policy: The teacher policy for computing target logits. + val_dataloader: Validation dataloader. + tokenizer: Tokenizer for processing text. + loss_fn: Distillation loss function. + step: Current training step (for logging). + master_config: Master configuration dictionary. + + Returns: + Tuple of (val_metrics, timing_metrics). + """ + if val_dataloader is None: + print(" ⚠️ No validation dataloader provided, skipping validation", flush=True) + return {}, {} + + timer = Timer() + + with timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...", flush=True) + + val_metrics: dict[str, Any] = {"val_loss": 0.0} + sum_num_valid_tokens = 0 + + val_batches = master_config["distillation"].get("val_batches", 0) + val_batch_size = master_config["distillation"].get( + "val_global_batch_size", + master_config["distillation"]["num_prompts_per_step"], + ) + val_mbs = master_config["distillation"].get( + "val_micro_batch_size", val_batch_size + ) + + for batch_idx, val_batch in enumerate(val_dataloader): + # Add loss masks for assistant tokens + for message_log in val_batch["message_log"]: + for message in message_log: + if "token_loss_mask" not in message: + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + 
message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + + # Flatten messages + flat_messages, input_lengths = batched_message_log_to_flat_message( + val_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"].get( + "make_sequence_length_divisible_by", 1 + ), + ) + + val_data = BatchedDataDict[DistillationLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": val_batch["loss_multiplier"], + } + ) + val_data.update(flat_messages.get_multimodal_dict(as_tensors=False)) + val_data.to("cpu") + + # Pad partial batch if needed (drop_last=False for val) + # Must pad BEFORE teacher logits to avoid size mismatch: + # teacher.get_topk_logits internally pads for its own DP sharding + # and returns padded-size outputs, so all inputs must be + # uniformly padded first. + if val_data.size < val_batch_size: + dp_size = student_policy.sharding_annotations.get_axis_size( + "data_parallel" + ) + val_data = maybe_pad_last_batch(val_data, dp_size, val_mbs) + + # Get teacher top-k logits + use_ipc = master_config["distillation"].get("use_ipc", True) + topk_k = master_config["distillation"]["topk_logits_k"] + + teacher_policy.prepare_for_lp_inference() + if use_ipc: + teacher_logits = teacher_policy.train( + val_data, + None, + eval_mode=True, + is_teacher=True, + topk_logits=topk_k, + gbs=val_data.size, + mbs=master_config["distillation"].get( + "val_micro_batch_size", + master_config["distillation"].get( + "val_global_batch_size", + master_config["distillation"]["num_prompts_per_step"], + ), + ), + ) + else: + teacher_topk = teacher_policy.get_topk_logits(val_data, k=topk_k) + val_data["teacher_topk_logits"] = teacher_topk["topk_logits"] + val_data["teacher_topk_indices"] = teacher_topk["topk_indices"] + del teacher_topk + teacher_policy.offload_after_refit() + + # Compute 
student validation loss (eval mode, no gradient updates) + student_policy.prepare_for_training() + if use_ipc: + val_results = student_policy.train( + val_data, + loss_fn, + eval_mode=True, + gbs=val_data.size, + mbs=val_mbs, + teacher_logits=teacher_logits, + ) + del teacher_logits + else: + val_results = student_policy.train( + val_data, + loss_fn, + eval_mode=True, + gbs=val_data.size, + mbs=val_mbs, + ) + + if len(val_results["all_mb_metrics"]) == 0: + warnings.warn( + "No validation metrics were collected for this batch." + " This is likely because there were no valid samples." + ) + else: + num_valid_tokens = ( + val_data["sample_mask"].unsqueeze(-1) * val_data["token_mask"] + ).sum() + val_metrics["val_loss"] += float(val_results["loss"]) * num_valid_tokens + sum_num_valid_tokens += num_valid_tokens + + if val_batches > 0 and batch_idx >= val_batches - 1: + break + + if sum_num_valid_tokens > 0: + val_metrics["val_loss"] /= sum_num_valid_tokens + else: + warnings.warn( + "No validation metrics were collected." + " This is likely because there were no valid samples in the validation set." 
+ ) + + student_policy.prepare_for_training() + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + if sum_num_valid_tokens > 0: + # Print summary of validation results + print("\n📊 Validation Results:") + print(f" • Validation loss: {val_metrics['val_loss']:.4f}") + + # Print timing information + print("\n ⏱️ Validation Timing:") + print(f" • Total validation time: {validation_time:.2f}s") + + # Make sure to reset the timer after validation + timer.reset() + + return val_metrics, timing_metrics + + +def off_policy_distillation_train( + student_policy: ColocatablePolicyInterface, + teacher_policy: ColocatablePolicyInterface, + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer: TokenizerType, + loss_fn: DistillationLossFn, + logger: Logger, + checkpointer: CheckpointManager, + distillation_save_state: OffPolicyDistillationSaveState, + master_config: OffPolicyMasterConfig, + eval_hook: Optional[Callable] = None, + eval_hook_period: int = 0, + eval_hook_at_start: bool = False, + token_aligner=None, + teacher_tokenizer=None, +) -> None: + """Run off-policy distillation training algorithm. + + Key differences from on-policy distillation_train(): + - No student_generation parameter (we don't generate responses) + - No task_to_env / val_task_to_env (no environment scoring) + - No rollout generation step - uses fixed responses from dataset directly + + Training loop: + 1. Load batch with prompt-response pairs (responses already in dataset) + 2. Add loss masks (train on assistant tokens only) + 3. Get teacher top-k logits for the fixed responses + 4. Train student with KL divergence loss + + Args: + eval_hook: Optional callback ``(step, student_policy, teacher_policy, logger) -> dict`` + called every *eval_hook_period* steps. Return value (if dict) is + logged under ``prefix="eval_hook"`` and used for checkpoint metric lookup. 
+ eval_hook_period: How often (in steps) to call *eval_hook*. 0 = disabled. + eval_hook_at_start: If True, call eval_hook before the first training step. + """ + timer = Timer() + timeout = TimeoutChecker( + timeout=master_config["checkpointing"].get("checkpoint_must_save_by", None), + fit_last_save_time=True, + ) + timeout.start_iterations() + + # common config/state items + current_epoch = distillation_save_state["current_epoch"] # current epoch + current_step = distillation_save_state[ + "current_step" + ] # current step within current epoch + total_steps = distillation_save_state[ + "total_steps" + ] # total number of steps across all epochs + consumed_samples = distillation_save_state["consumed_samples"] + total_valid_tokens = distillation_save_state["total_valid_tokens"] + max_epochs = master_config["distillation"][ + "max_num_epochs" + ] # max number of epochs to train for + max_steps = master_config["distillation"][ + "max_num_steps" + ] # max number of steps to train for + + # Validation configuration + val_period = master_config["distillation"].get("val_period", 0) + val_at_start = master_config["distillation"].get("val_at_start", False) + + # Run validation at the start if configured + if val_at_start and total_steps == 0: + print("\n🔍 Running initial validation...", flush=True) + val_metrics, validation_timings = validate( + student_policy, + teacher_policy, + val_dataloader, + tokenizer, + loss_fn, + step=0, + master_config=master_config, + ) + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") + + # Run eval hook at start if configured + eval_hook_metrics = None + if eval_hook and eval_hook_at_start and total_steps == 0: + print("\n🔍 Running initial eval hook...", flush=True) + eval_hook_metrics = eval_hook( + step=0, + student_policy=student_policy, + teacher_policy=teacher_policy, + logger=logger, + ) + if isinstance(eval_hook_metrics, dict): + 
logger.log_metrics(eval_hook_metrics, 0, prefix="eval_hook") + + # Run off-policy distillation training + batch: BatchedDataDict[DatumSpec] + + # Create a process pool for cross-tokenizer processing (if enabled). + cross_tokenizer_enabled = token_aligner is not None and teacher_tokenizer is not None + token_aligner_cfg = master_config.get("token_aligner", {}) + dp_chunk_size = int(token_aligner_cfg.get("dp_chunk_size", 128)) + # Default to DP-only for parity/stability; char-offset is opt-in. + use_char_offset = bool(token_aligner_cfg.get("use_char_offset", False)) + if bool(token_aligner_cfg.get("force_dp_only", False)): + # Backward-compatible override for older configs. + use_char_offset = False + use_align_fast = bool(token_aligner_cfg.get("use_align_fast", False)) + ct_num_workers = master_config["distillation"].get("cross_tokenizer_num_workers", None) + if ct_num_workers is None: + ct_num_workers = os.cpu_count() or 1 + ct_pool: Optional[ProcessPoolExecutor] = None + if cross_tokenizer_enabled and ct_num_workers > 1: + mp_ctx = multiprocessing.get_context("forkserver") + ct_pool = ProcessPoolExecutor( + max_workers=ct_num_workers, + mp_context=mp_ctx, + initializer=_init_align_worker, + initargs=(token_aligner, dp_chunk_size, use_align_fast), + ) + print( + f" ✓ Cross-tokenizer process pool created with {ct_num_workers} workers " + f"(dp_chunk_size={dp_chunk_size})", + flush=True, + ) + if cross_tokenizer_enabled: + print(f" ✓ TokenAligner mode: use_char_offset={use_char_offset}", flush=True) + print(f" ✓ TokenAligner DP mode: use_align_fast={use_align_fast}", flush=True) + + ct_prefetch_pool: Optional[ThreadPoolExecutor] = None + if cross_tokenizer_enabled: + ct_prefetch_pool = ThreadPoolExecutor(max_workers=1) + + def _shutdown_alignment_pools() -> None: + if ct_prefetch_pool is not None: + ct_prefetch_pool.shutdown(wait=False, cancel_futures=True) + if ct_pool is not None: + ct_pool.shutdown(wait=False) + + def _prepare_train_batch_data(batch_obj: 
BatchedDataDict[DatumSpec]): + # Add loss mask for each message (train on assistant tokens only) + # Skip if token_loss_mask already exists from data processor + for message_log in batch_obj["message_log"]: + for message in message_log: + if "token_loss_mask" not in message: + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + + # Convert message_log to flat format for training + flat_messages_obj, input_lengths_obj = batched_message_log_to_flat_message( + batch_obj["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"].get( + "make_sequence_length_divisible_by", 1 + ), + ) + + train_data_obj = BatchedDataDict[DistillationLossDataDict]( + { + "input_ids": flat_messages_obj["token_ids"], + "input_lengths": input_lengths_obj, + "token_mask": flat_messages_obj["token_loss_mask"], + "sample_mask": batch_obj["loss_multiplier"], + } + ) + train_data_obj.update( + flat_messages_obj.get_multimodal_dict(as_tensors=False) + ) + train_data_obj.to("cpu") + return flat_messages_obj, input_lengths_obj, train_data_obj + + def _get_max_teacher_len() -> int: + return int( + master_config["teacher"].get( + "max_total_sequence_length", + master_config["policy"]["max_total_sequence_length"], + ) + ) + + def _resolve_cross_tokenizer_batch_data( + train_data_obj: BatchedDataDict[DistillationLossDataDict], + batch_obj: BatchedDataDict[DatumSpec], + ct_future_obj: Any, + ) -> tuple[torch.Tensor, list[Any], BatchedDataDict]: + if ct_future_obj is not None: + return ct_future_obj.result() + return _process_cross_tokenizer_batch( + train_input_ids=train_data_obj["input_ids"], + batch_loss_multiplier=batch_obj["loss_multiplier"], + extra_env=batch_obj.get("extra_env_info"), + tokenizer=tokenizer, + teacher_tokenizer=teacher_tokenizer, + token_aligner=token_aligner, + 
use_char_offset=use_char_offset, + use_align_fast=use_align_fast, + dp_chunk_size=dp_chunk_size, + ct_pool=ct_pool, + max_teacher_len_rt=_get_max_teacher_len(), + ) + + def _maybe_prefetch_next_batch( + dataloader_iter_obj: Any, + ) -> Optional[_PrefetchedBatchPack]: + if not cross_tokenizer_enabled or ct_prefetch_pool is None: + return None + + try: + next_batch_obj = next(dataloader_iter_obj) + except StopIteration: + return None + + next_flat_messages, next_input_lengths, next_train_data = _prepare_train_batch_data( + next_batch_obj + ) + next_ct_future = ct_prefetch_pool.submit( + _process_cross_tokenizer_batch, + next_train_data["input_ids"], + next_batch_obj["loss_multiplier"], + next_batch_obj.get("extra_env_info"), + tokenizer, + teacher_tokenizer, + token_aligner, + use_char_offset, + use_align_fast, + dp_chunk_size, + ct_pool, + _get_max_teacher_len(), + ) + return { + "batch": next_batch_obj, + "flat_messages": next_flat_messages, + "input_lengths": next_input_lengths, + "train_data": next_train_data, + "ct_future": next_ct_future, + } + + while total_steps < max_steps and current_epoch < max_epochs: + print( + f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_epochs} {'=' * 25}", + flush=True, + ) + + dataloader_iter = iter(dataloader) + prefetched_batch_pack: Optional[_PrefetchedBatchPack] = None + while total_steps < max_steps: + if prefetched_batch_pack is not None: + batch = prefetched_batch_pack["batch"] + flat_messages = prefetched_batch_pack["flat_messages"] + input_lengths = prefetched_batch_pack["input_lengths"] + train_data = prefetched_batch_pack["train_data"] + ct_future = prefetched_batch_pack["ct_future"] + prefetched_batch_pack = None + loaded_from_prefetch = True + else: + try: + batch = next(dataloader_iter) + except StopIteration: + break + loaded_from_prefetch = False + ct_future = None + + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_steps)} {'=' * 25}", + flush=True, + ) + 
maybe_gpu_profile_step(student_policy, total_steps + 1) + val_metrics, validation_timings = None, None + + with timer.time("total_step_time"): + # ==== Data Processing ==== + # Off-policy: Use responses from dataset directly (no generation) + if not loaded_from_prefetch: + print("▶ Processing batch data (off-policy - using fixed responses)...", flush=True) + with timer.time("data_processing"): + flat_messages, input_lengths, train_data = _prepare_train_batch_data(batch) + else: + print("▶ Using prefetched batch data...", flush=True) + # Keep timing key stable in logs when using prefetched data. + with timer.time("data_processing"): + pass + + # ==== Cross-Tokenizer Data Processing ==== + teacher_data = None + + if cross_tokenizer_enabled: + with timer.time("cross_tokenizer_processing"): + teacher_input_ids, aligned_pairs, teacher_data = ( + _resolve_cross_tokenizer_batch_data( + train_data_obj=train_data, + batch_obj=batch, + ct_future_obj=ct_future, + ) + ) + + loss_fn.set_cross_tokenizer_data( + teacher_input_ids=teacher_input_ids, + aligned_pairs=aligned_pairs, + ) + + # Prepare one-step-ahead cross-tokenizer preprocessing in the background. 
+ if ( + cross_tokenizer_enabled + and prefetched_batch_pack is None + ): + prefetched_batch_pack = _maybe_prefetch_next_batch(dataloader_iter) + + # ==== Teacher Logprob Inference ==== + use_ipc = master_config["distillation"].get("use_ipc", True) + topk_k = master_config["distillation"]["topk_logits_k"] + + print("▶ Preparing for teacher logprob inference...", flush=True) + with timer.time("teacher_logprob_inference_prep"): + student_policy.offload_after_refit() + teacher_policy.prepare_for_lp_inference() + + teacher_fwd_data = teacher_data if cross_tokenizer_enabled else train_data + teacher_topk_k = None if cross_tokenizer_enabled else topk_k + + if use_ipc: + print("▶ Computing teacher logprobs (IPC)...", flush=True) + with timer.time("teacher_logprob_inference"): + teacher_logits = teacher_policy.train( + teacher_fwd_data, + None, + eval_mode=True, + is_teacher=True, + topk_logits=teacher_topk_k, + gbs=master_config["policy"]["train_global_batch_size"], + mbs=master_config["policy"]["train_micro_batch_size"], + ) + else: + if cross_tokenizer_enabled: + raise NotImplementedError( + "Cross-tokenizer distillation requires use_ipc=True. " + "Set distillation.use_ipc: true in the config." 
+ ) + print("▶ Computing teacher logprobs (non-IPC, data dict)...", flush=True) + with timer.time("teacher_logprob_inference"): + teacher_topk = teacher_policy.get_topk_logits(train_data, k=topk_k) + train_data["teacher_topk_logits"] = teacher_topk["topk_logits"] + train_data["teacher_topk_indices"] = teacher_topk["topk_indices"] + del teacher_topk + + # ==== Student Training ==== + print("▶ Preparing for training...", flush=True) + with timer.time("training_prep"): + teacher_policy.offload_after_refit() + student_policy.prepare_for_training() + + if cross_tokenizer_enabled: + if not getattr(student_policy, '_loss_fn_initialized', False): + student_policy._loss_fn_initialized = True + token_aligner_cfg = master_config.get("token_aligner", {}) + student_policy.init_cross_tokenizer_loss_fn( + loss_config=master_config["loss_fn"], + token_aligner_config={ + "teacher_model": master_config["teacher"]["model_name"], + "student_model": master_config["policy"]["model_name"], + "projection_matrix_path": token_aligner_cfg["projection_matrix_path"], + "use_sparse_format": token_aligner_cfg.get("use_sparse_format", True), + "learnable": token_aligner_cfg.get("learnable", False), + "max_comb_len": token_aligner_cfg.get("max_comb_len", 4), + "projection_matrix_multiplier": token_aligner_cfg.get("projection_matrix_multiplier", 1.0), + "project_teacher_to_student": token_aligner_cfg.get("project_teacher_to_student", False), + }, + ) + student_policy.update_cross_tokenizer_data( + teacher_input_ids=teacher_input_ids, + aligned_pairs=aligned_pairs, + ) + + student_loss_fn = None if cross_tokenizer_enabled else loss_fn + print("▶ Training policy...", flush=True) + with timer.time("policy_training"): + if use_ipc: + train_results = student_policy.train( + train_data, student_loss_fn, teacher_logits=teacher_logits + ) + del teacher_logits + else: + train_results = student_policy.train(train_data, student_loss_fn) + + is_last_step = (total_steps + 1 >= max_steps) or ( + (current_epoch + 
1 == max_epochs) + and (current_step + 1 == len(dataloader)) + ) + + # ==== Validation ==== + if val_period > 0 and (total_steps + 1) % val_period == 0: + val_metrics, validation_timings = validate( + student_policy, + teacher_policy, + val_dataloader, + tokenizer, + loss_fn, + step=total_steps + 1, + master_config=master_config, + ) + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" + ) + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + # ==== Eval Hook (e.g., generation-based MATH/MMLU eval) ==== + if eval_hook and eval_hook_period > 0 and (total_steps + 1) % eval_hook_period == 0: + print(f"\n🔍 Running eval hook at step {total_steps + 1}...", flush=True) + with timer.time("eval_hook"): + eval_hook_metrics = eval_hook( + step=total_steps + 1, + student_policy=student_policy, + teacher_policy=teacher_policy, + logger=logger, + ) + if isinstance(eval_hook_metrics, dict): + logger.log_metrics(eval_hook_metrics, total_steps + 1, prefix="eval_hook") + student_policy.prepare_for_training() + + # ==== Metrics ==== + metrics = { + "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), + "mean_seq_length": batch["length"].numpy().mean(), + "total_num_tokens": input_lengths.numpy().sum(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k in { + "lr", + "wd", + "global_valid_seqs", + "global_valid_toks", + "mean_seq_length", + }: + metrics[k] = np.mean(v).item() + else: + metrics[k] = np.sum(v).item() + total_valid_tokens += metrics["global_valid_toks"] + + ## Checkpointing + consumed_samples += master_config["distillation"][ + "num_prompts_per_step" + ] + timeout.mark_iteration() + + should_save_by_step = ( + is_last_step + or (total_steps + 1) % master_config["checkpointing"]["save_period"] + == 0 + ) + # Check if timeout-based checkpointing is enabled in config. 
+ should_save_by_timeout = timeout.check_save() + + if master_config["checkpointing"]["enabled"] and ( + should_save_by_step or should_save_by_timeout + ): + student_policy.prepare_for_training() + + distillation_save_state["current_epoch"] = current_epoch + distillation_save_state["current_step"] = current_step + 1 + distillation_save_state["total_steps"] = total_steps + 1 + distillation_save_state["total_valid_tokens"] = total_valid_tokens + distillation_save_state["consumed_samples"] = consumed_samples + + full_metric_name = master_config["checkpointing"]["metric_name"] + if full_metric_name is not None: + assert full_metric_name.startswith( + "train:" + ) or full_metric_name.startswith("val:"), ( + f"metric_name={full_metric_name} must start with 'val:' or 'train:',\n" + f'followed by the corresponding name in the "val" or "train" metrics dictionary. ' + f"Example: 'train:loss' or 'val:val_loss'" + ) + prefix, metric_name = full_metric_name.split(":", 1) + metrics_source = metrics if prefix == "train" else val_metrics + if not metrics_source: + warnings.warn( + f"You asked to save checkpoints based on {metric_name} but no {prefix} metrics were collected. 
" + "This checkpoint will not be saved as top-k.", + stacklevel=2, + ) + if full_metric_name in distillation_save_state: + del distillation_save_state[full_metric_name] + elif metric_name not in metrics_source: + raise ValueError( + f"Metric {metric_name} not found in {prefix} metrics" + ) + else: + distillation_save_state[full_metric_name] = metrics_source[ + metric_name + ] + + with timer.time("checkpointing"): + print( + f"Saving checkpoint for step {total_steps + 1}...", + flush=True, + ) + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, distillation_save_state, master_config + ) + student_policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + checkpointing_cfg=master_config["checkpointing"], + ) + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + # Logging + # Log training data + log_data = {"content": flat_messages["content"]} + log_data["input_lengths"] = input_lengths.tolist() + logger.log_batched_dict_as_jsonl( + log_data, f"train_data_step{total_steps + 1}.jsonl" + ) + + timing_metrics: dict[str, float] = timer.get_timing_metrics( + reduction_op="sum" + ) # type: ignore + + print("\n📊 Training Results:") + + print(f" • Loss: {metrics['loss']:.4f}") + print(f" • Grad Norm: {metrics['grad_norm']:.4f}") + print(f" • Mean Sequence Length: {metrics['mean_seq_length']:.1f}") + + if "total_flops" in train_results: + total_time = timing_metrics.get("total_step_time", 0) + total_tflops = ( + train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", + flush=True, + ) + if 
"theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", + flush=True, + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops + + print("\n⏱️ Timing:", flush=True) + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + + total_num_gpus = ( + master_config["cluster"]["num_nodes"] + * master_config["cluster"]["gpus_per_node"] + ) + metrics.update( + { + "tokens_per_sec_per_gpu": metrics["total_num_tokens"] + / total_time + / total_num_gpus + } + ) + + print(f" • Total step time: {total_time:.2f}s", flush=True) + + # Display all other timing metrics + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) + + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + if should_save_by_timeout: + print("Timeout has been reached, stopping training early", flush=True) + _shutdown_alignment_pools() + return + if total_steps >= max_steps: + print( + "Max number of steps has been reached, stopping training early", + flush=True, + ) + _shutdown_alignment_pools() + return + + # End of epoch + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch + + _shutdown_alignment_pools() diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 1011761e33..57d158cb4d 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -13,18 +13,27 @@ # limitations under the License. 
import os import warnings -from typing import NotRequired, Optional, TypedDict, cast +from pathlib import Path +import sys +if sys.version_info >= (3, 11): + from typing import NotRequired, Optional, TypedDict, cast +else: + from typing import Optional, TypedDict, cast + from typing_extensions import NotRequired import numpy as np import torch from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer, PreTrainedTokenizerBase -from nemo_rl.algorithms.loss.loss_functions import NLLLossFn +from nemo_rl.algorithms.loss_functions import ( + NLLLoss, +) from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import TaskDataSpec from nemo_rl.data.llm_message_utils import ( add_loss_mask_to_message_log, batched_message_log_to_flat_message, @@ -67,9 +76,6 @@ class SFTConfig(TypedDict): val_global_batch_size: int val_micro_batch_size: int val_at_start: bool - # Whether to run validation on the last training step. Setting this to True ensures the - # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). 
- val_at_end: bool seed: int @@ -89,13 +95,13 @@ def setup( master_config: MasterConfig, tokenizer: AutoTokenizer, train_dataset: AllTaskProcessedDataset, - val_dataset: Optional[AllTaskProcessedDataset], + val_dataset: AllTaskProcessedDataset, ) -> tuple[ Policy, RayVirtualCluster, StatefulDataLoader, - Optional[StatefulDataLoader], - NLLLossFn, + StatefulDataLoader, + NLLLoss, Logger, CheckpointManager, SFTSaveState, @@ -148,17 +154,14 @@ def setup( ) train_dataloader.load_state_dict(dataloader_state_dict) - if val_dataset is not None: - val_dataloader = StatefulDataLoader( - val_dataset, - batch_size=sft_config["val_global_batch_size"], - shuffle=False, - collate_fn=rl_collate_fn, - drop_last=False, - num_workers=data_config["num_workers"], - ) - else: - val_dataloader = None + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=sft_config["val_global_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + drop_last=False, + num_workers=data_config["num_workers"], + ) # ========================== # Cluster @@ -190,25 +193,24 @@ def setup( processor = tokenizer tokenizer = processor.tokenizer - weights_path, optimizer_path = checkpointer.get_resume_paths(last_checkpoint_path) - policy = Policy( cluster=cluster, config=policy_config, tokenizer=tokenizer, processor=processor, - weights_path=weights_path, - optimizer_path=optimizer_path, + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" + if last_checkpoint_path + else None, init_optimizer=True, init_reference_model=False, ) # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = NLLLossFn( - use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"] - and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"] - ) + loss_fn = NLLLoss() print(" ✓ Model initialized") print("\n" + "=" * 60) @@ -233,22 +235,23 @@ def 
setup( # ======================================================= def validate( policy: PolicyInterface, - val_dataloader: Optional[StatefulDataLoader], + val_dataloader: StatefulDataLoader, tokenizer, loss_fn, step: int, master_config: MasterConfig, + sft_task_spec: TaskDataSpec, val_batches: int, val_batch_size: int, val_mbs: int, ): """Run validation on the validation dataset.""" if val_dataloader is None: - assert master_config["sft"]["val_period"] <= 0, ( - "val_dataloader is None, so sft.val_period must be <= 0" + assert val_dataloader is not None or master_config["dpo"]["val_period"] == 0, ( + "val_dataloader is None, so dpo.val_period must be 0" ) print(" ⚠️ No validation dataloader provided, skipping validation") - return {}, {} + return timer = Timer() @@ -357,6 +360,7 @@ def sft_train( loss_fn, master_config, logger, + sft_task_spec, checkpointer, sft_save_state: SFTSaveState, ) -> None: @@ -386,7 +390,6 @@ def sft_train( # Validation configuration val_period = sft_config["val_period"] val_at_start = sft_config["val_at_start"] - val_at_end = sft_config["val_at_end"] max_num_epochs = sft_config["max_num_epochs"] # Run validation at the start if configured @@ -399,6 +402,7 @@ def sft_train( loss_fn, step=0, master_config=master_config, + sft_task_spec=sft_task_spec, val_batches=sft_config["val_batches"], val_batch_size=sft_config["val_global_batch_size"], val_mbs=sft_config["val_micro_batch_size"], @@ -454,11 +458,7 @@ def sft_train( print("▶ Taking a training step...") with timer.time("policy_training"): - train_results = policy.train( - train_data, - loss_fn, - timer=timer, - ) + train_results = policy.train(train_data, loss_fn) is_last_step = total_steps + 1 >= master_config["sft"][ "max_num_steps" @@ -467,10 +467,8 @@ def sft_train( and current_step + 1 == len(train_dataloader) ) - # Run validation if it's a validation step or last step with val_at_end - if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( - val_at_end and is_last_step - ): 
+ # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: val_metrics, validation_timings = validate( policy, val_dataloader, @@ -478,6 +476,7 @@ def sft_train( loss_fn, step=total_steps + 1, master_config=master_config, + sft_task_spec=sft_task_spec, val_batches=sft_config["val_batches"], val_batch_size=sft_config["val_global_batch_size"], val_mbs=sft_config["val_micro_batch_size"], @@ -502,7 +501,7 @@ def sft_train( metrics[k] = np.mean(v).item() else: metrics[k] = np.sum(v).item() - total_valid_tokens += metrics.get("global_valid_toks", 0) + total_valid_tokens += metrics["global_valid_toks"] ## Checkpointing sft_save_state["consumed_samples"] += master_config["policy"][ @@ -560,15 +559,14 @@ def sft_train( checkpoint_path = checkpointer.init_tmp_checkpoint( total_steps + 1, sft_save_state, master_config ) + policy.save_checkpoint( weights_path=os.path.join( checkpoint_path, "policy", "weights" ), optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" - ) - if checkpointer.save_optimizer - else None, + ), tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), @@ -617,12 +615,9 @@ def sft_train( master_config["cluster"]["num_nodes"] * master_config["cluster"]["gpus_per_node"] ) - if total_time > 0: - timing_metrics["valid_tokens_per_sec_per_gpu"] = ( - metrics.get("global_valid_toks", 0) / total_time / total_num_gpus - ) - else: - timing_metrics["valid_tokens_per_sec_per_gpu"] = 0.0 + timing_metrics["valid_tokens_per_sec_per_gpu"] = ( + metrics["global_valid_toks"] / total_time / total_num_gpus + ) logger.log_metrics(metrics, total_steps + 1, prefix="train") logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") diff --git a/nemo_rl/algorithms/x_token/__init__.py b/nemo_rl/algorithms/x_token/__init__.py new file mode 100644 index 0000000000..f78d230fc4 --- /dev/null +++ b/nemo_rl/algorithms/x_token/__init__.py @@ -0,0 +1,3 @@ +from nemo_rl.algorithms.x_token.tokenalign 
import TokenAligner
+
+__all__ = ["TokenAligner"]
diff --git a/nemo_rl/algorithms/x_token/minimal_projection_generator.py b/nemo_rl/algorithms/x_token/minimal_projection_generator.py
new file mode 100644
index 0000000000..89711f7a4a
--- /dev/null
+++ b/nemo_rl/algorithms/x_token/minimal_projection_generator.py
@@ -0,0 +1,572 @@
+import torch
+import os
+import argparse
+from transformers import AutoTokenizer, AutoModel, AutoConfig
+# from sentence_transformers import SentenceTransformer
+from tqdm.auto import tqdm
+import re
+
+# TODO: verify KL divergence and top-5 agreement with this projection matrix.
+
+# TODO: read vocab sizes from the model config rather than the tokenizer,
+# since the embedding matrix can be larger than the tokenizer vocabulary.
+
+EXACT_MATCH_ONLY = False
+
+# --- Configuration and Setup ---
+parser = argparse.ArgumentParser(description="Generate a sparse projection map between two tokenizers.")
+parser.add_argument("--model_a_index", type=int, default=1, help="Index of the source model (Model A / Student).")
+parser.add_argument("--model_b_index", type=int, default=0, help="Index of the target model (Model B / Teacher).")
+parser.add_argument("--model_a_name", type=str, default=None, help="HuggingFace model name for the source model (Model A / Student). If provided, overrides --model_a_index.")
+parser.add_argument("--model_b_name", type=str, default=None, help="HuggingFace model name for the target model (Model B / Teacher). If provided, overrides --model_b_index.")
+parser.add_argument("--keep_top_tokens", type=int, default=-1, help="Number of top tokens to keep for each vocabulary. -1 means all.")
+parser.add_argument("--data_dir", type=str, default="cross_tokenizer_data/", help="Directory for importance scores and cached data.")
+parser.add_argument("--top_k", type=int, default=10, help="Number of top projections to keep for each token.")
+parser.add_argument("--weight_threshold", type=float, default=0.0, help="Minimum weight threshold to keep a projection.
Values below this will be filtered out.") +parser.add_argument("--force_recompute", action='store_true', help="Force recomputation of embeddings even if cached files exist.") +parser.add_argument("--skip_exact_enforcement", action='store_true', help="Skip enforcing exact matches between tokens.") +parser.add_argument("--use_canonicalization", action='store_true', help="Apply token canonicalization before generating embeddings to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n).") +args = parser.parse_args() + +args.skip_exact_enforcement = True + +MODEL_LIST = [ + "nvidia/Mistral-NeMo-Minitron-8B-Base", + "Qwen/Qwen3-8B-Base", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.1-8B", + "google/gemma-3-4b-it", + "google/gemma-2b", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "openai/gpt-oss-20b", + "microsoft/phi-4", + "google/gemma-3-12b-pt", +] +EMBEDDING_MODEL_CHOICES = [ + {"name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-mpnet-base-v2", "type": "sbert"}, + {"name": "sentence-transformers/all-MiniLM-L6-v2", "type": "sbert"}, + {"name": "Qwen/Qwen3-Embedding-4B", "type": "llm_first_layer"}, + {"name": "Qwen/Qwen3-Embedding-0.6B", "type": "llm_first_layer"}, +] + +MAX_SEQ_LENGTH_EMBEDDING = 64 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def sinkhorn(A, n_iters=10): + for _ in range(n_iters): + if _ % 2 == 0: + # A = A / (A.sum(dim=0, keepdim=True) + 1e-6) + col_sums = A.sum(dim=0, keepdim=True) + safe_col_sums = torch.where(col_sums == 0, torch.ones_like(col_sums), col_sums) + A = A / safe_col_sums + else: + #0, 2, 4, 6 + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, 
keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +# --- Helper Functions --- + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + if 'mini' in name: + cleaned_name += "_mini" + return cleaned_name + +def load_tokenizer(model_id_or_path): + """Loads a HuggingFace tokenizer, setting a pad token if necessary.""" + try: + tok = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + except Exception as e: + print(f"Error loading tokenizer for model '{model_id_or_path}': {e}") + print(f"Available models in MODEL_LIST (indices 0-{len(MODEL_LIST)-1}):") + for i, model in enumerate(MODEL_LIST): + print(f" {i}: {model}") + raise + +def validate_model_selection(args): + """Validates that the model selection arguments are valid.""" + # Check if both name and index are provided for the same model + if args.model_a_name is not None and args.model_a_index != 1: # 1 is the default + print("Warning: Both --model_a_name and --model_a_index provided. Using --model_a_name.") + + if args.model_b_name is not None and args.model_b_index != 0: # 0 is the default + print("Warning: Both --model_b_name and --model_b_index provided. 
Using --model_b_name.")
+
+    # Validate indices if names are not provided
+    if args.model_a_name is None:
+        if args.model_a_index < 0 or args.model_a_index >= len(MODEL_LIST):
+            raise ValueError(f"model_a_index {args.model_a_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}")
+
+    if args.model_b_name is None:
+        if args.model_b_index < 0 or args.model_b_index >= len(MODEL_LIST):
+            raise ValueError(f"model_b_index {args.model_b_index} is out of range. Available models: 0-{len(MODEL_LIST)-1}")
+
+    # Check if the same model is selected for both A and B
+    model_a_id = args.model_a_name if args.model_a_name is not None else MODEL_LIST[args.model_a_index]
+    model_b_id = args.model_b_name if args.model_b_name is not None else MODEL_LIST[args.model_b_index]
+
+    if model_a_id == model_b_id:
+        raise ValueError(f"Cannot use the same model for both A and B: {model_a_id}")
+
+def save_data(data, filename):
+    """Saves data to a torch file."""
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
+    torch.save(data.cpu(), filename)
+    print(f"Data saved to {filename}")
+
+def load_data(filename):
+    """Loads data from a torch file."""
+    return torch.load(filename)
+
+def get_llm_first_layer_embeddings(decoded_tokens_list, llm_embedding_tokenizer, llm_embedding_model, max_seq_length_embedding, device, batch_size=32):
+    """Generates embeddings using the first layer of a given LLM."""
+    all_embeddings = []
+    llm_embedding_model.eval()
+    embedding_dim = llm_embedding_model.config.hidden_size
+
+    for i in tqdm(range(0, len(decoded_tokens_list), batch_size), desc="Encoding tokens with LLM"):
+        batch_tokens = decoded_tokens_list[i:i + batch_size]
+        inputs = llm_embedding_tokenizer(
+            batch_tokens, return_tensors="pt", padding=True, truncation=True,
+            max_length=max_seq_length_embedding, add_special_tokens=False,
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = llm_embedding_model(**inputs, output_hidden_states=True)
+            first_layer_output = outputs.hidden_states[0]
+
+        for
k in range(first_layer_output.shape[0]): + valid_token_mask = inputs['attention_mask'][k] == 1 + if valid_token_mask.sum() > 0: + pooled_embedding = first_layer_output[k, valid_token_mask].mean(dim=0) + all_embeddings.append(pooled_embedding) + else: + all_embeddings.append(torch.zeros(embedding_dim, device=device)) + + return torch.stack(all_embeddings).to(device) + + +def compute_chunked_projection_map(embeddings_query, embeddings_corpus, args, device, chunk_size=1000): + """Computes projection map in chunks to save memory.""" + num_queries = embeddings_query.shape[0] + target_vocab_size = embeddings_corpus.shape[0] + + # Pre-allocate result tensors + all_top_k_indices = torch.zeros((num_queries, args.top_k), dtype=torch.long) + all_top_k_likelihoods = torch.zeros((num_queries, args.top_k), dtype=torch.float32) + + # Normalize corpus embeddings once + embeddings_corpus_norm = torch.nn.functional.normalize(embeddings_corpus.to(device).float(), p=2, dim=1) + + for chunk_start in tqdm(range(0, num_queries, chunk_size), desc="Processing chunks"): + chunk_end = min(chunk_start + chunk_size, num_queries) + chunk_query = embeddings_query[chunk_start:chunk_end].to(device).float() + + with torch.no_grad(): + # Compute similarities for this chunk + chunk_query_norm = torch.nn.functional.normalize(chunk_query, p=2, dim=1) + similarities = torch.matmul(chunk_query_norm, embeddings_corpus_norm.t()) + + # Generate projection map for this chunk + chunk_top_k_indices, chunk_top_k_likelihoods = generate_projection_map_chunk(similarities, args) + + # Store results + all_top_k_indices[chunk_start:chunk_end] = chunk_top_k_indices.cpu() + all_top_k_likelihoods[chunk_start:chunk_end] = chunk_top_k_likelihoods.cpu() + + # Clear GPU memory + del similarities, chunk_query_norm, chunk_top_k_indices, chunk_top_k_likelihoods + torch.cuda.empty_cache() + + return all_top_k_indices, all_top_k_likelihoods + +def generate_projection_map_chunk(similarities, args): + """Calculates the sparse 
likelihood map from a similarity matrix chunk.""" + similarities = similarities.abs() + similarities[similarities > 0.999999999] = 1.0 + max_similarities = torch.max(similarities, dim=1, keepdim=True)[0] + sharpness = 10.0 * max_similarities + likelihood = similarities ** sharpness + + # Normalize rows + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + threshold_mask = top_k_likelihood >= args.weight_threshold + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + return top_k_indices, top_k_likelihood + +def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, 
sparse_projection_matrix)
+    return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size)
+
+def debug_projection_map(top_k_indices, top_k_likelihood, source_tokenizer, target_tokenizer, direction="", N=2000):
+    """Debug function to show the first N rows with decoded tokens and weights."""
+    N = min(N, top_k_indices.shape[0])  # Show first N rows or fewer
+    print(f"\n--- Debugging projection map {direction} (first {N} rows) ---")
+
+    for row_idx in range(N):
+        # Decode source token
+        try:
+            token_id = row_idx if row_idx >= 0 else top_k_indices.shape[0] + row_idx
+            source_token = source_tokenizer.decode([token_id])
+            source_token_str = repr(source_token)  # Use repr to show special chars
+        except Exception:
+            source_token_str = f"<decode error: id {token_id}>"
+
+        # Build the target tokens with weights string
+        row_indices = top_k_indices[row_idx].cpu().numpy()
+        row_weights = top_k_likelihood[row_idx].float().cpu().numpy()
+
+        weight_total = 0
+        target_parts = []
+
+        # Only show rows whose last kept weight equals the row maximum
+        if row_weights.max() != row_weights[-1]:
+            continue
+
+        for target_idx, weight in zip(row_indices, row_weights):
+            try:
+                target_token = target_tokenizer.decode([target_idx])
+                target_token_str = repr(target_token)
+            except Exception:
+                target_token_str = f"<decode error: id {target_idx}>"
+
+            target_parts.append(f"{target_token_str}({weight:.4f})")
+            weight_total += weight
+
+        target_string = " ".join(target_parts)
+        print(f"{source_token_str} -> {target_string}")
+
+def generate_projection_map(similarities, args):
+    """Calculates the sparse likelihood map from a similarity matrix."""
+    similarities = similarities.abs()
+    similarities[similarities > 0.999999999] = 1.0
+    max_similarities = torch.max(similarities, dim=1, keepdim=True)[0]
+    sharpness = 10.0 * max_similarities
+    likelihood = similarities ** sharpness
+
+    # Create a sparse representation by keeping only top-k values
+    # top_k_likelihood_pre_norm,
_ = likelihood.topk(args.top_k, dim=1) + # likelihood = likelihood.where(likelihood >= top_k_likelihood_pre_norm[:, -1:], torch.zeros_like(likelihood)) + + # Normalize the row to sum to 1, handling rows that are all zero + # row_sums = likelihood.sum(dim=1, keepdim=True) + # safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + # likelihood = likelihood / safe_row_sums + # pdb.set_trace() + # likelihood = sinkhorn_one_dim(likelihood) + + # Get the final top-k values and their indices from the sparse, normalized likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Store top-k values before zeroing (to avoid losing them) + row_indices = torch.arange(likelihood.shape[0]).unsqueeze(1).expand(-1, args.top_k) + top_k_values = likelihood[row_indices, top_k_indices].clone() + + # Zero out entire likelihood matrix in-place, then restore only top-k elements + likelihood.zero_() + likelihood[row_indices, top_k_indices] = top_k_values + + # likelihood = sinkhorn(likelihood, n_iters=1) + # likelihood = sinkhorn(likelihood, n_iters=1) works the best + + + likelihood = sinkhorn_one_dim(likelihood) + + # Extract final top-k values from the normalized sparse likelihood matrix + top_k_likelihood, top_k_indices = likelihood.topk(args.top_k, dim=1) + + # Apply weight threshold filtering if specified + if args.weight_threshold > 0.0: + print(f"Applying weight threshold filter: {args.weight_threshold}") + # Create mask for values above threshold + # pdb.set_trace() + threshold_mask = top_k_likelihood >= args.weight_threshold + + #set indices to -1 where threshold is not met + top_k_indices = top_k_indices.where(threshold_mask, torch.full_like(top_k_indices, -1)) + + # # Count how many values per row are above threshold + # valid_counts = threshold_mask.sum(dim=1) + # total_filtered = (valid_counts == 0).sum().item() + # total_kept = threshold_mask.sum().item() + # total_possible = top_k_likelihood.numel() + + # 
print(f"Kept {total_kept}/{total_possible} ({100*total_kept/total_possible:.1f}%) projections above threshold") + + # if total_filtered > 0: + # print(f"Warning: {total_filtered} tokens have no projections above threshold {args.weight_threshold}") + + # # Zero out values below threshold + # filtered_likelihood = top_k_likelihood * threshold_mask.to(top_k_likelihood.dtype) + # filtered_indices = top_k_indices.clone() + + # # For rows with no values above threshold, keep the top value to avoid empty rows + # empty_rows = valid_counts == 0 + # if empty_rows.any(): + # print(f"Keeping top projection for {empty_rows.sum().item()} tokens with no values above threshold") + # filtered_likelihood[empty_rows, 0] = top_k_likelihood[empty_rows, 0] + + # top_k_likelihood = filtered_likelihood + # top_k_indices = filtered_indices + + # pdb.set_trace() + + return top_k_indices, top_k_likelihood + +# --- Main Execution --- +if __name__ == "__main__": + # Validate model selection arguments + validate_model_selection(args) + + # 1. 
Load Tokenizers and deterministically assign A and B + # Use model names if provided, otherwise use indices + if args.model_a_name is not None: + model_1 = {'id': args.model_a_name} + print(f"Using provided model A name: {args.model_a_name}") + else: + model_1 = {'id': MODEL_LIST[args.model_a_index]} + print(f"Using model A from index {args.model_a_index}: {model_1['id']}") + + model_1['name'] = model_1['id'].split("/")[-1] + print(f"Loading first tokenizer: {model_1['name']}") + model_1['tokenizer'] = load_tokenizer(model_1['id']) + + if args.model_b_name is not None: + model_2 = {'id': args.model_b_name} + print(f"Using provided model B name: {args.model_b_name}") + else: + model_2 = {'id': MODEL_LIST[args.model_b_index]} + print(f"Using model B from index {args.model_b_index}: {model_2['id']}") + + model_2['name'] = model_2['id'].split("/")[-1] + print(f"Loading second tokenizer: {model_2['name']}") + model_2['tokenizer'] = load_tokenizer(model_2['id']) + + # Deterministically assign model_A and model_B based on alphabetical order of names + if model_1['name'] > model_2['name']: + model_A, model_B = model_2, model_1 + else: + model_A, model_B = model_1, model_2 + + print(f"\nAssigned Source (A): {model_A['name']}") + print(f"Assigned Target (B): {model_B['name']}") + + source_vocab_size = model_A['tokenizer'].vocab_size + target_vocab_size = model_B['tokenizer'].vocab_size + # get the top k tokens from the source and target vocab from model config file + model_A_config = AutoConfig.from_pretrained(model_A['id'], trust_remote_code=True if 'nvidia' in model_A['id'] else False) + model_B_config = AutoConfig.from_pretrained(model_B['id'], trust_remote_code=True if 'nvidia' in model_B['id'] else False) + # pdb.set_trace() + source_vocab_size = model_A_config.vocab_size + if "gemma" not in model_B['id']: + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + # print(f"Source top k tokens: 
{model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Source vocab size (full): {source_vocab_size}") + print(f"Target vocab size (full): {target_vocab_size}") + # exit() + + + + if 0: + # just debugging learned projection map + # learned_projection_map = torch.load("models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_transformation_matrices/learned_projection_map_latest.pt") + # learned_projection_map = torch.load("cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_64_double.pt") + learned_projection_map = torch.load("cross_tokenizer_data/projection_matrix_learned_llama_qwen_top5.pt") + top_k_indices_A_to_B = learned_projection_map["indices"] + top_k_likelihood_A_to_B = learned_projection_map["likelihoods"] + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B", N=150000) + exit() + + + + + + + + # 2. Select and Load Embedding Model + embedding_model_index = 3 # Default to a good LLM embedder + selected_model_info = EMBEDDING_MODEL_CHOICES[embedding_model_index] + embedding_model_name = selected_model_info["name"] + embedding_model_type = selected_model_info["type"] + print(f"\nUsing embedding model: {embedding_model_name} ({embedding_model_type})") + + # 3. 
Generate or Load Embeddings + canonicalization_suffix = "_canonical" if args.use_canonicalization else "_raw" + embeddings_path_A = os.path.join(args.data_dir, f"embeddings_{model_A['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + embeddings_path_B = os.path.join(args.data_dir, f"embeddings_{model_B['name']}_{embedding_model_name.replace('/', '_')}_full{canonicalization_suffix}.pt") + + if not args.force_recompute and os.path.exists(embeddings_path_A) and os.path.exists(embeddings_path_B): + print("Loading cached embeddings...") + model_A['embeddings'] = load_data(embeddings_path_A).to(DEVICE) + model_B['embeddings'] = load_data(embeddings_path_B).to(DEVICE) + else: + print("Generating new embeddings...") + + # Generate raw decoded tokens + raw_tokens_A = [model_A['tokenizer'].decode([idx]) for idx in range(model_A['tokenizer'].vocab_size)] + raw_tokens_B = [model_B['tokenizer'].decode([idx]) for idx in range(model_B['tokenizer'].vocab_size)] + + # Apply canonicalization if requested + if args.use_canonicalization: + # Import canonicalization function + import sys + sys.path.append('.') + from tokenalign import TokenAligner + + print("Applying token canonicalization before embedding generation...") + decoded_tokens_A = [TokenAligner._canonical_token(token) for token in raw_tokens_A] + decoded_tokens_B = [TokenAligner._canonical_token(token) for token in raw_tokens_B] + + # Show some examples of canonicalization + print("Canonicalization examples:") + for i in range(min(10, len(raw_tokens_A))): + if raw_tokens_A[i] != decoded_tokens_A[i]: + print(f" Model A: '{raw_tokens_A[i]}' -> '{decoded_tokens_A[i]}'") + for i in range(min(10, len(raw_tokens_B))): + if raw_tokens_B[i] != decoded_tokens_B[i]: + print(f" Model B: '{raw_tokens_B[i]}' -> '{decoded_tokens_B[i]}'") + + print(f"Applied canonicalization to {len(decoded_tokens_A)} tokens for model A and {len(decoded_tokens_B)} tokens for model B") + else: + print("Using raw decoded 
tokens without canonicalization") + decoded_tokens_A = raw_tokens_A + decoded_tokens_B = raw_tokens_B + + if embedding_model_type == "sbert": + sbert_model = SentenceTransformer(embedding_model_name, device=DEVICE) + model_A['embeddings'] = sbert_model.encode(decoded_tokens_A, convert_to_tensor=True, show_progress_bar=True) + model_B['embeddings'] = sbert_model.encode(decoded_tokens_B, convert_to_tensor=True, show_progress_bar=True) + elif embedding_model_type == "llm_first_layer": + llm_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, trust_remote_code=True) + if llm_tokenizer.pad_token is None: llm_tokenizer.pad_token = llm_tokenizer.eos_token + llm_model = AutoModel.from_pretrained( + embedding_model_name, torch_dtype=torch.bfloat16, trust_remote_code=True + ).to(DEVICE) + model_A['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_A, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + model_B['embeddings'] = get_llm_first_layer_embeddings(decoded_tokens_B, llm_tokenizer, llm_model, MAX_SEQ_LENGTH_EMBEDDING, DEVICE) + + save_data(model_A['embeddings'], embeddings_path_A) + save_data(model_B['embeddings'], embeddings_path_B) + + # 4. 
Compute Similarity and Generate Projection Maps (chunked to save memory) + print("\nComputing projection map in chunks to save memory...") + chunk_size = 500 # Process 500 tokens at a time to avoid OOM + top_k_indices_A_to_B, top_k_likelihood_A_to_B = compute_chunked_projection_map( + model_A['embeddings'], model_B['embeddings'], args, DEVICE, chunk_size=chunk_size) + + # Note: Exact match enforcement is skipped in chunked mode for simplicity + # The chunked approach processes similarities in small batches to avoid OOM + if 0: + debug_projection_map(top_k_indices_A_to_B, top_k_likelihood_A_to_B, model_A['tokenizer'], model_B['tokenizer'], "A -> B") + + # print("Generating B -> A projection map...") + # top_k_indices_B_to_A, top_k_likelihood_B_to_A = generate_projection_map(similarities.T, args) + # debug_projection_map(top_k_indices_B_to_A, top_k_likelihood_B_to_A, model_B['tokenizer'], model_A['tokenizer'], "B -> A") + + # 5. Save the Combined Projection Map + print("\nSaving combined projection map...") + model_a_clean_name = clean_model_name_for_filename(model_A['name']) + model_b_clean_name = clean_model_name_for_filename(model_B['name']) + # output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_bidirectional_top_{args.top_k}.pt" + output_filename = f"temp_projection_map_{model_a_clean_name}_to_{model_b_clean_name}_top_{args.top_k}" + # if args.skip_exact_enforcement: + # output_filename += "_no_exact" + output_filename += f".pt" + if args.weight_threshold > 0.0: + output_filename = output_filename.replace(".pt", f"_thresh_{args.weight_threshold:.3f}.pt") + output_path = os.path.join(args.data_dir, output_filename) + + torch.save({ + "indices": top_k_indices_A_to_B.cpu(), + "likelihoods": top_k_likelihood_A_to_B.cpu(), + "model_A_id": model_A['id'], + "model_B_id": model_B['id'], + }, output_path) + + # torch.save({ + # "A_to_B": { + # "indices": top_k_indices_A_to_B.cpu(), + # "likelihoods": top_k_likelihood_A_to_B.cpu() + # }, + 
# "B_to_A": { + # "indices": top_k_indices_B_to_A.cpu(), + # "likelihoods": top_k_likelihood_B_to_A.cpu() + # }, + # "model_A_id": model_A['id'], + # "model_B_id": model_B['id'], + # }, output_path) + print(f"Saved combined projection map to: {output_path}") + + # 6. Example Usage of the Projection Function + print("\n--- Testing projection function (A -> B) ---") + # Create a dummy likelihood tensor: [BATCH, SEQ, vocab_size_A] + source_vocab_size_A = model_A['embeddings'].shape[0] + target_vocab_size_B = model_B['embeddings'].shape[0] + dummy_tensor = torch.randn(1, 4096, source_vocab_size_A, device=DEVICE, dtype=torch.bfloat16) + + # Transform this tensor using the projection map (convert to float32 for compatibility) + projected_tensor = project_token_likelihoods(dummy_tensor.float(), top_k_indices_A_to_B, top_k_likelihood_A_to_B, target_vocab_size_B, DEVICE) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful.") diff --git a/nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py b/nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py new file mode 100644 index 0000000000..ad993efc29 --- /dev/null +++ b/nemo_rl/algorithms/x_token/minimal_projection_via_multitoken.py @@ -0,0 +1,929 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + + + +###### save as dense format and set indices to -1 where not used + + +#remove all special tokens that start with <| and end with |> + + +# compare 3 ways to estimate likelihood matrix: +# 1. using embeddings from another model, like was done in minimal_projection_generator.py +# 2. 
using text analysis like in tokenalign_likelihood_estimate.py +# 3. use one token to multiple and assign those as transformation matrix + +# this file implements 3rd way + +def sinkhorn_one_dim(A, n_iters=1): + for _ in range(n_iters): + + + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + + return A + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def create_weight_distribution(num_tokens): + """Create weight distribution for multi-token mappings""" + # if num_tokens == 1: + # return [1.0] + # elif num_tokens == 2: + # return [0.7, 0.3] + # elif num_tokens == 3: + # return [0.6, 0.3, 0.1] + # else: + if 1: + # For more tokens, use exponential decay + weights = [] + base = 0.9 + for i in range(num_tokens): + if i == 0: + weights.append(base) + else: + weights.append(base * (0.1 ** i)) + + # Normalize to sum to 1 + total = sum(weights) + weights = [w / total for w in weights] + return weights + +def find_similar_special_tokens(tokenizer_a, tokenizer_b, similarity_threshold=0.4, top_k_matches=3): + """Find similar special tokens between two tokenizers using string similarity.""" + + def is_special_token(token_str): + """Check if a token looks like a special token""" + return (token_str.startswith('<|') and token_str.endswith('|>')) or \ + (token_str.startswith('<') and token_str.endswith('>')) or \ + token_str in ['', '', '', '', '', ''] + + def extract_special_tokens(tokenizer): + """Extract all special tokens from a tokenizer with their IDs""" + special_tokens = {} + vocab = tokenizer.get_vocab() + for token_str, token_id in vocab.items(): + if is_special_token(token_str): + special_tokens[token_id] = token_str + return special_tokens + + def 
calculate_similarity(token_a, token_b): + """Calculate similarity between two token strings""" + # Use difflib for sequence similarity + seq_similarity = difflib.SequenceMatcher(None, token_a, token_b).ratio() + + # Extract key words from special tokens for semantic matching + def extract_keywords(token): + # Remove special token markers and split by common separators + cleaned = re.sub(r'[<>|_]', ' ', token.lower()) + words = [w for w in cleaned.split() if len(w) > 2] # Filter short words + return set(words) + + keywords_a = extract_keywords(token_a) + keywords_b = extract_keywords(token_b) + + # Jaccard similarity for keywords + if keywords_a or keywords_b: + keyword_similarity = len(keywords_a.intersection(keywords_b)) / len(keywords_a.union(keywords_b)) + else: + keyword_similarity = 0.0 + + # Combined similarity (weighted average) + return 0.6 * seq_similarity + 0.4 * keyword_similarity + + print("Extracting special tokens...") + special_tokens_a = extract_special_tokens(tokenizer_a) # student + special_tokens_b = extract_special_tokens(tokenizer_b) # teacher + + print(f"Found {len(special_tokens_a)} special tokens in student tokenizer") + print(f"Found {len(special_tokens_b)} special tokens in teacher tokenizer") + + # Find matches + special_token_mappings = [] + + print("Finding similar special tokens...") + for token_id_a, token_str_a in special_tokens_a.items(): + similarities = [] + for token_id_b, token_str_b in special_tokens_b.items(): + similarity = calculate_similarity(token_str_a, token_str_b) + if similarity >= similarity_threshold: + similarities.append((token_id_b, token_str_b, similarity)) + + # Sort by similarity and take top-k + similarities.sort(key=lambda x: x[2], reverse=True) + for token_id_b, token_str_b, similarity in similarities[:top_k_matches]: + special_token_mappings.append({ + 'student_id': token_id_a, + 'student_token': token_str_a, + 'teacher_id': token_id_b, + 'teacher_token': token_str_b, + 'similarity': similarity + }) + + 
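The blended score computed by `calculate_similarity` above (60% `difflib` sequence ratio plus 40% keyword Jaccard) can be exercised in isolation. A minimal, self-contained sketch follows; the example token strings are chosen purely for illustration and are not taken from any particular tokenizer:

```python
import difflib
import re

def special_token_similarity(token_a: str, token_b: str) -> float:
    """Blend of sequence similarity (60%) and keyword Jaccard similarity (40%),
    mirroring the scoring used for special-token matching above."""
    seq = difflib.SequenceMatcher(None, token_a, token_b).ratio()

    def keywords(tok: str) -> set:
        # Strip special-token markers, lowercase, keep words longer than 2 chars
        return {w for w in re.sub(r"[<>|_]", " ", tok.lower()).split() if len(w) > 2}

    ka, kb = keywords(token_a), keywords(token_b)
    jac = len(ka & kb) / len(ka | kb) if (ka or kb) else 0.0
    return 0.6 * seq + 0.4 * jac

# Identical tokens score 1.0; near-identical spellings outscore unrelated tokens.
print(special_token_similarity("<|im_start|>", "<|im_start|>"))  # 1.0
print(special_token_similarity("<|end_of_text|>", "<|endoftext|>") >
      special_token_similarity("<|end_of_text|>", "<|python_tag|>"))  # True
```

With the default `similarity_threshold=0.4`, a pure string-overlap match like `<|end_of_text|>` vs. `<|endoftext|>` passes on sequence similarity alone even when keyword overlap is zero, which is the reason the two components are blended rather than multiplied.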
return special_token_mappings + + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + "--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize 
different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args = parse_arguments() + + # Configuration from arguments + ENABLE_SCALE_TRICK = args.enable_scale_trick + ENABLE_REVERSE_PASS = args.enable_reverse_pass + ENABLE_EXACT_MATCH = args.enable_exact_match + + TOKENS_TO_CUT = args.tokens_to_cut + TOP_K = args.top_k + USE_RAW_TOKENS = args.use_raw_tokens + INITIAL_PROJECTION_PATH = args.initial_projection_path + ENABLE_SPECIAL_TOKEN_MAPPING = args.enable_special_token_mapping + SPECIAL_TOKEN_SIMILARITY_THRESHOLD = args.special_token_similarity_threshold + SPECIAL_TOKEN_TOP_K = args.special_token_top_k if args.special_token_top_k is not None else TOP_K + USE_CANONICALIZATION = args.use_canonicalization + + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + + # Print configuration + print("=== Configuration ===") + print(f"Student model: {student_model_name}") + print(f"Teacher model: 
{teacher_model_name}") + print(f"Enable scale trick: {ENABLE_SCALE_TRICK}") + print(f"Enable reverse pass: {ENABLE_REVERSE_PASS}") + print(f"Enable exact match: {ENABLE_EXACT_MATCH}") + print(f"Use raw tokens: {USE_RAW_TOKENS}") + print(f"Use canonicalization: {USE_CANONICALIZATION}") + print(f"Tokens to cut: {TOKENS_TO_CUT}") + print(f"Top K: {TOP_K}") + print(f"Enable special token mapping: {ENABLE_SPECIAL_TOKEN_MAPPING}") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"Special token similarity threshold: {SPECIAL_TOKEN_SIMILARITY_THRESHOLD}") + print(f"Special token top K: {SPECIAL_TOKEN_TOP_K}") + print(f"Initial projection path: {INITIAL_PROJECTION_PATH}") + print(f"Output directory: {args.output_dir}") + print("=" * 25) + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + # pdb.set_trace() + if "gemma" not in student_model_name.lower(): + source_vocab_size = model_A_config.vocab_size + else: + source_vocab_size = model_A_config.text_config.vocab_size + + if "gemma" not in teacher_model_name.lower(): + target_vocab_size = model_B_config.vocab_size + else: + target_vocab_size = model_B_config.text_config.vocab_size + + tokenizer_student_total_vocab_size = source_vocab_size + tokenizer_teacher_total_vocab_size = target_vocab_size + # print(f"Source top k tokens: {model_A_top_k_tokens}") + # print(f"Target top k tokens: {model_B_top_k_tokens}") + + print(f"Student tokenizer total vocab size: {tokenizer_student_total_vocab_size}") + print(f"Teacher tokenizer total vocab size: {tokenizer_teacher_total_vocab_size}") + + # Print token processing mode + if USE_RAW_TOKENS: + print("Using raw token representation 
(convert_ids_to_tokens)") + else: + print("Using decoded token representation (decode)") + + + transformation_counts = defaultdict(float) + import os + if INITIAL_PROJECTION_PATH and os.path.exists(INITIAL_PROJECTION_PATH): + print(f"Loading initial projection from: {INITIAL_PROJECTION_PATH}") + initial_projection_map = torch.load(INITIAL_PROJECTION_PATH, map_location='cpu') + + if isinstance(initial_projection_map, dict) and 'indices' in initial_projection_map and 'likelihoods' in initial_projection_map: + print("Loading from sparse top-k format and converting to transformation_counts.") + indices = initial_projection_map['indices'] + likelihoods = initial_projection_map['likelihoods'] + + loaded_student_model = initial_projection_map.get('model_A_id') + loaded_teacher_model = initial_projection_map.get('model_B_id') + + if loaded_student_model and loaded_student_model != student_model_name: + print(f"Warning: Student model mismatch. Loaded: {loaded_student_model}, Current: {student_model_name}") + if loaded_teacher_model and loaded_teacher_model != teacher_model_name: + print(f"Warning: Teacher model mismatch. 
Loaded: {loaded_teacher_model}, Current: {teacher_model_name}") + + num_student_tokens = indices.shape[0] + top_k = indices.shape[1] + + for student_id in tqdm.tqdm(range(num_student_tokens), desc="Converting initial projection to counts"): + for k in range(top_k): + teacher_id = indices[student_id, k].item() + if teacher_id != -1: + likelihood = likelihoods[student_id, k].item() + if likelihood > 0: + transformation_counts[(student_id, teacher_id)] = likelihood + + elif torch.is_tensor(initial_projection_map): + if initial_projection_map.is_sparse: + print("Loading from sparse tensor and converting to transformation_counts.") + sparse_matrix = initial_projection_map.coalesce() + map_indices = sparse_matrix.indices() + map_values = sparse_matrix.values() + for i in tqdm.tqdm(range(map_indices.shape[1]), desc="Converting sparse tensor to counts"): + student_id = map_indices[0, i].item() + teacher_id = map_indices[1, i].item() + weight = map_values[i].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print("Loading from dense matrix and converting to transformation_counts.") + dense_matrix = initial_projection_map + non_zero_indices = torch.nonzero(dense_matrix, as_tuple=False) + for idx in tqdm.tqdm(range(non_zero_indices.shape[0]), desc="Converting dense projection to counts"): + student_id = non_zero_indices[idx, 0].item() + teacher_id = non_zero_indices[idx, 1].item() + weight = dense_matrix[student_id, teacher_id].item() + if weight > 0: + transformation_counts[(student_id, teacher_id)] = weight + else: + print(f"Warning: Unrecognized format for initial projection map at {INITIAL_PROJECTION_PATH}. 
Skipping.")
+
+    print(f"Initialized transformation_counts with {len(transformation_counts)} entries.")
+
+    ignore_tokens = ['<|endoftext|>', '', ]
+    ignore_student_ids = {tokenizer_student.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_student.get_vocab()}
+    ignore_teacher_ids = {tokenizer_teacher.convert_tokens_to_ids(token) for token in ignore_tokens if token in tokenizer_teacher.get_vocab()}
+
+    # Get all teacher tokens and decode them
+    teacher_vocab = tokenizer_teacher.get_vocab()
+    teacher_tokens_decoded = {}
+
+    print("Decoding teacher tokens...")
+    for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"):
+        if token_id in ignore_teacher_ids:
+            continue
+        try:
+            # Get token representation based on configuration
+            if USE_RAW_TOKENS:
+                decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0]
+            else:
+                decoded = tokenizer_teacher.decode([token_id])
+
+            # Apply canonicalization if enabled
+            decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION)
+            teacher_tokens_decoded[token_id] = decoded
+        except Exception:
+            # Skip tokens that can't be processed
+            continue
+
+    print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens")
+
+    # Find multi-token mappings
+    multi_token_examples = []
+
+    print("Finding multi-token mappings...")
+
+    # First pass: Student tokens -> Teacher tokens
+    if 1:
+        print("\n=== FIRST PASS: Student tokens -> Teacher tokens ===")
+
+        # Get all student tokens and decode them
+        student_vocab = tokenizer_student.get_vocab()
+        student_tokens_decoded = {}
+
+        print("Decoding student tokens...")
+        for token_id in tqdm.tqdm(range(tokenizer_student_total_vocab_size), desc="Decoding student tokens"):
+            if token_id in ignore_student_ids:
+                continue
+            try:
+                # Get token representation based on configuration
+                if USE_RAW_TOKENS:
+                    decoded = 
tokenizer_student.convert_ids_to_tokens([token_id])[0]
+                else:
+                    decoded = tokenizer_student.decode([token_id])
+
+                if decoded.startswith("<|") and decoded.endswith("|>"):
+                    print(f"Skipping special token: {decoded}")
+                    continue
+
+                # Apply canonicalization if enabled
+                decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION)
+                student_tokens_decoded[token_id] = decoded
+            except Exception:
+                # Skip tokens that can't be processed
+                continue
+
+        print(f"Successfully decoded {len(student_tokens_decoded)} student tokens")
+
+        reverse_multi_token_examples = []
+        print("Finding student -> teacher multi-token mappings...")
+        for student_token_id, student_token_str in tqdm.tqdm(student_tokens_decoded.items(), desc="Processing student tokens"):
+            # Tokenize the student token string using the teacher tokenizer
+            teacher_encoding = tokenizer_teacher(student_token_str, add_special_tokens=False, return_attention_mask=False)
+            teacher_token_ids = teacher_encoding['input_ids']
+
+            # Skip if any teacher token is in the ignore list
+            if any(tid in ignore_teacher_ids for tid in teacher_token_ids):
+                continue
+
+            # Keep only the first TOKENS_TO_CUT teacher tokens
+            teacher_token_ids = teacher_token_ids[:TOKENS_TO_CUT]
+
+            # Get weight distribution based on number of teacher tokens
+            weights = create_weight_distribution(len(teacher_token_ids))
+
+            # Add to transformation matrix (student_token_id -> teacher_token_id)
+            for teacher_token_id, weight in zip(teacher_token_ids, weights):
+                transformation_counts[(student_token_id, teacher_token_id)] += weight
+
+            # Collect examples for analysis
+            if len(teacher_token_ids) >= 2:
+                teacher_tokens_decoded_reverse = [tokenizer_teacher.decode([tid]) for tid in teacher_token_ids]
+                reverse_multi_token_examples.append({
+                    'student_token': student_token_str,
+                    'student_id': student_token_id,
+                    'teacher_tokens': teacher_tokens_decoded_reverse,
+                    'teacher_ids': teacher_token_ids,
+                    'weights': weights
+                })
+
+    # Second pass: Teacher tokens -> Student tokens
(opposite direction) + if ENABLE_REVERSE_PASS: + print("\n=== secod PASS: Teacher tokens -> Student tokens ===") + + # Get all teacher tokens and decode them + teacher_vocab = tokenizer_teacher.get_vocab() + teacher_tokens_decoded = {} + + print("Decoding teacher tokens...") + for token_id in tqdm.tqdm(range(tokenizer_teacher_total_vocab_size), desc="Decoding teacher tokens"): + if token_id in ignore_teacher_ids: + continue + try: + # Get token representation based on configuration + if USE_RAW_TOKENS: + decoded = tokenizer_teacher.convert_ids_to_tokens([token_id])[0] + else: + decoded = tokenizer_teacher.decode([token_id]) + + if decoded.startswith("<|") and decoded.endswith("|>"): + print(f"Skipping special token: {decoded}") + continue + + # Apply canonicalization if enabled + decoded = apply_canonicalization_if_enabled(decoded, USE_CANONICALIZATION) + teacher_tokens_decoded[token_id] = decoded + except: + # Skip tokens that can't be processed + continue + + print(f"Successfully decoded {len(teacher_tokens_decoded)} teacher tokens") + + teacher_to_student_multi_token_examples = [] + print("Finding teacher->student multi-token mappings...") + for teacher_token_id, teacher_token_str in tqdm.tqdm(teacher_tokens_decoded.items(), desc="Processing teacher tokens"): + # Tokenize the teacher token string using student tokenizer + student_encoding = tokenizer_student(teacher_token_str, add_special_tokens=False, return_attention_mask=False) + student_token_ids = student_encoding['input_ids'] + + # Skip if any student token is in ignore list + if any(sid in ignore_student_ids for sid in student_token_ids): + continue + + # Cut to only first 4 tokens + student_token_ids = student_token_ids[:TOKENS_TO_CUT] + + # Get weight distribution based on number of student tokens + weights = create_weight_distribution(len(student_token_ids)) + + # Add to transformation matrix (student_token_id -> teacher_token_id mapping) + if 1: + for student_token_id, weight in zip(student_token_ids, 
weights): + transformation_counts[(student_token_id, teacher_token_id)] += weight + + # Collect examples for analysis + if len(student_token_ids) >= 2: + student_tokens_decoded_reverse = [tokenizer_student.decode([sid]) for sid in student_token_ids] + teacher_to_student_multi_token_examples.append({ + 'teacher_token': teacher_token_str, + 'teacher_id': teacher_token_id, + 'student_tokens': student_tokens_decoded_reverse, + 'student_ids': student_token_ids, + 'weights': weights + }) + + print(f"\n=== ADDING SPECIAL TOKEN MAPPINGS ===") + + # Find and add special token mappings (if enabled) + special_token_mappings = [] + if ENABLE_SPECIAL_TOKEN_MAPPING: + special_token_mappings = find_similar_special_tokens( + tokenizer_student, + tokenizer_teacher, + similarity_threshold=SPECIAL_TOKEN_SIMILARITY_THRESHOLD, + top_k_matches=SPECIAL_TOKEN_TOP_K + ) + else: + print("Special token mapping disabled") + + if special_token_mappings: + print(f"\nFound {len(special_token_mappings)} special token mappings:") + initial_transformation_count = len(transformation_counts) + + # Add ALL mappings to transformation matrix + for mapping in special_token_mappings: + student_id = mapping['student_id'] + teacher_id = mapping['teacher_id'] + similarity = mapping['similarity'] + + # Add mapping with weight based on similarity + weight = similarity * 0.8 # Scale similarity to reasonable weight + transformation_counts[(student_id, teacher_id)] += weight + + # Group mappings by student token and show top 2 matches per student token + from collections import defaultdict + student_mappings = defaultdict(list) + for mapping in special_token_mappings: + student_mappings[mapping['student_id']].append(mapping) + + # Sort each student's mappings by similarity and show top 2 + print("Top 2 matches per student special token:") + shown_count = 0 + for student_id, mappings in student_mappings.items(): + # Sort by similarity (highest first) + sorted_mappings = sorted(mappings, key=lambda x: 
x['similarity'], reverse=True) + + # Show top 2 for this student token + student_token = sorted_mappings[0]['student_token'] # Get student token name + print(f" {student_token}:") + + for mapping in sorted_mappings[:2]: + similarity = mapping['similarity'] + weight = similarity * 0.8 + print(f" -> '{mapping['teacher_token']}' (similarity: {similarity:.3f}, weight: {weight:.3f})") + shown_count += 1 + + if len(sorted_mappings) > 2: + print(f" ... and {len(sorted_mappings) - 2} more matches") + + total_hidden = len(special_token_mappings) - shown_count + if total_hidden > 0: + print(f"Total mappings not shown: {total_hidden}") + + added_count = len(transformation_counts) - initial_transformation_count + print(f"Added {added_count} new special token transformation entries") + else: + print("No similar special tokens found") + + print(f"\n=== SUMMARY ===") + print(f"Found {len(multi_token_examples)} teacher tokens that map to multiple student tokens") + # exit() + # Show some examples + if multi_token_examples: + print("\nExamples of multi-token mappings:") + for i, example in enumerate(multi_token_examples[:10]): + print(f" Teacher '{example['teacher_token']}' -> Student {example['student_tokens']} (weights: {example['weights']})") + if len(multi_token_examples) > 10: + print(f" ... 
and {len(multi_token_examples) - 10} more.") + + if ENABLE_REVERSE_PASS: + print(f"\nReverse pass enabled - added bidirectional mappings") + + print(f"\nTotal transformation entries: {len(transformation_counts)}") + + if ENABLE_EXACT_MATCH: + + print("Checking for exact token matches and setting exact mappings...") + # check exact match between student and teacher tokens and set those as perfect 1-to-1 mappings + # Convert all tokens to strings at once for vectorized comparison + # pdb.set_trace() + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # For tokens that match exactly, we want their mapping to be 1.0 + # and they should not be mapped to any other token. 
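Editor's note: the exact-match override described in the comment above (drop every existing mapping for an exactly-matching student token, then pin it 1-to-1 to its teacher counterpart with weight 1.0) can be sketched in isolation. The token IDs and weights below are invented purely for illustration:

```python
from collections import defaultdict

# Toy transformation_counts: (student_id, teacher_id) -> weight.
# All IDs/weights here are hypothetical, not from a real tokenizer pair.
transformation_counts = defaultdict(float, {
    (0, 7): 0.6, (0, 9): 0.4,  # student token 0 currently maps to two teacher tokens
    (1, 3): 1.0,               # student token 1 maps to one teacher token
})
match_indices_student = [0]    # student token 0 is an exact string match...
match_indices_teacher = [5]    # ...for teacher token 5

# First, remove all existing mappings for the exactly-matching student tokens
exact = set(match_indices_student)
for key in [k for k in transformation_counts if k[0] in exact]:
    del transformation_counts[key]

# Then pin each exact match to a single 1-to-1 mapping with weight 1.0
for student_id, teacher_id in zip(match_indices_student, match_indices_teacher):
    transformation_counts[(student_id, teacher_id)] = 1.0

print(dict(transformation_counts))  # {(1, 3): 1.0, (0, 5): 1.0}
```

This guarantees that any token string shared by both vocabularies projects all of its probability mass onto its exact counterpart, rather than being diluted across multi-token approximations.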
+        # First, remove all existing mappings for these student tokens
+        match_indices_student_set = set(match_indices_student)
+        keys_to_remove = []
+        for key in transformation_counts.keys():
+            student_id, teacher_id = key
+            if student_id in match_indices_student_set:
+                keys_to_remove.append(key)
+        for key in keys_to_remove:
+            del transformation_counts[key]
+        # Then, set the perfect 1-to-1 mappings for exact matches
+        for student_id, teacher_id in zip(match_indices_student, match_indices_teacher):
+            transformation_counts[(student_id, teacher_id)] = 1.0
+
+    def debug_projection_map(transformation_counts, source_tokenizer, target_tokenizer, direction="", N=50):
+        """Debug function to show projection mappings with decoded tokens and weights."""
+        print(f"\n--- Debugging projection map {direction} (showing {N} examples) ---")
+
+        # Group transformation_counts by source (student) token
+        source_to_targets = defaultdict(list)
+        for (source_id, target_id), weight in transformation_counts.items():
+            source_to_targets[source_id].append((target_id, weight))
+
+        # Sort by source token ID and take the last N
+        sorted_sources = sorted(source_to_targets.keys())[-N:]
+
+        for source_id in sorted_sources:
+            # Decode the source token
+            try:
+                if USE_RAW_TOKENS:
+                    source_token = source_tokenizer.convert_ids_to_tokens([source_id])[0]
+                else:
+                    source_token = source_tokenizer.decode([source_id])
+                source_token = apply_canonicalization_if_enabled(source_token, USE_CANONICALIZATION)
+                source_token_str = repr(source_token)  # Use repr to show special chars
+            except Exception:
+                source_token_str = f"<decode failed: {source_id}>"
+
+            # Sort targets by weight (descending) and build the target string
+            targets_weights = sorted(source_to_targets[source_id], key=lambda x: x[1], reverse=True)
+
+            target_parts = []
+            for target_id, weight in targets_weights:
+                try:
+                    if USE_RAW_TOKENS:
+                        target_token = target_tokenizer.convert_ids_to_tokens([target_id])[0]
+                    else:
+                        target_token = target_tokenizer.decode([target_id])
+                    target_token = apply_canonicalization_if_enabled(target_token, USE_CANONICALIZATION)
+                    target_token_str = repr(target_token)
+                except Exception:
+                    target_token_str = f"<decode failed: {target_id}>"
+                target_parts.append(f"{target_token_str}({weight:.4f})")
+
+            target_string = " ".join(target_parts)
+            print(f"{source_token_str} -> {target_string}")
+
+    # Create the transformation matrix (student -> teacher projection)
+    indices = list(transformation_counts.keys())
+    values = list(transformation_counts.values())
+
+    teacher_indices = [idx[1] for idx in indices]
+    student_indices = [idx[0] for idx in indices]
+
+    # Create a sparse tensor with student tokens as rows and teacher tokens as
+    # columns. This is the student -> teacher projection matrix.
+    indices_tensor = torch.LongTensor([student_indices, teacher_indices])
+    values_tensor = torch.FloatTensor(values)
+
+    transformation_matrix_sparse = torch.sparse_coo_tensor(
+        indices_tensor,
+        values_tensor,
+        (tokenizer_student_total_vocab_size, tokenizer_teacher_total_vocab_size),
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        dtype=torch.bfloat16
+    )
+
+    print(f"Created sparse student->teacher projection matrix with shape: {transformation_matrix_sparse.shape}")
+    print(f"Non-zero elements: {transformation_matrix_sparse._nnz()}")
+
+    if False:  # Disabled: densifying the full matrix does not fit in memory
+        # Calculate mapping statistics from the sparse matrix
+        print("\nCalculating mapping statistics from projection matrix...")
+
+        # Count non-zero elements in each row (each row = one student token)
+        dense_matrix = transformation_matrix_sparse.to_dense()
+        non_zero_counts_per_row = (dense_matrix != 0).sum(dim=1)
+
+        # Create statistics
+        mapping_stats = defaultdict(int)
+        for count in non_zero_counts_per_row:
+            mapping_stats[count.item()] += 1
+
+        # Print mapping statistics
+        print("\nMapping statistics (student tokens -> teacher tokens):")
+        for i in range(1, 5):  # 1, 2, 3, 4 teacher tokens
+            count = mapping_stats.get(i, 0)
+            print(f"Student tokens mapping to {i} teacher tokens: {count}")
+
+        total_mapped = sum(mapping_stats.values())
+        print(f"Total student tokens mapped: {total_mapped}")
+
+    # Save in the same format as minimal_projection_generator.py
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Convert the defaultdict to a regular dict for saving
+    transformation_counts_dict = dict(transformation_counts)
+
+    # Show some examples of the projection mappings
+    debug_projection_map(transformation_counts_dict, tokenizer_student, tokenizer_teacher,
+                         direction="student->teacher", N=1000)
+
+    print(f"\nConverting sparse matrix to top-{TOP_K} dense format...")
+
+    # Convert the sparse matrix to dense and take the top-k values per row
+    print("Converting to dense matrix on CPU to avoid memory issues...")
+    dense_matrix = transformation_matrix_sparse.cpu().to_dense()
+    print(f"Dense matrix shape: {dense_matrix.shape}")
+
+    # Get top-k values and indices for each row (each source token)
+    print(f"Extracting top-{TOP_K} values per token...")
+
+    # Apply Sinkhorn normalization on CPU
+    print("Applying Sinkhorn normalization on CPU...")
+    dense_matrix = sinkhorn_one_dim(dense_matrix, n_iters=1)
+
+    # Extract top-k on CPU
+    top_k_likelihoods, top_k_indices = torch.topk(dense_matrix, k=min(TOP_K, dense_matrix.shape[1]), dim=1)
+
+    # Handle the case where the vocabulary has fewer tokens than TOP_K
+    actual_k = top_k_indices.shape[1]
+    if actual_k < TOP_K:
+        print(f"Warning: Target vocabulary size ({dense_matrix.shape[1]}) is smaller than TOP_K ({TOP_K}). Using k={actual_k}")
+        # Pad with -1 indices and 0.0 likelihoods to maintain a consistent shape
+        pad_size = TOP_K - actual_k
+        top_k_indices = torch.cat([top_k_indices, torch.full((top_k_indices.shape[0], pad_size), -1, dtype=top_k_indices.dtype)], dim=1)
+        top_k_likelihoods = torch.cat([top_k_likelihoods, torch.zeros((top_k_likelihoods.shape[0], pad_size), dtype=top_k_likelihoods.dtype)], dim=1)
+
+    # Apply SCALE_TRICK: set the last column of likelihoods to 0.2 if enabled
+    if ENABLE_SCALE_TRICK:
+        print("ENABLE_SCALE_TRICK is True: setting last column of likelihoods to 0.2")
+        top_k_likelihoods[:, -1] = 0.2
+        if ENABLE_EXACT_MATCH:
+            for indices in match_indices_student:
+                top_k_likelihoods[indices, -1] = 0.0
+            print(f"Set last column of likelihoods to 0.0 for {len(match_indices_student)} exact matches as exact match is enabled")
+
+    # Apply Sinkhorn normalization again after the scale trick
+    print("Applying Sinkhorn normalization on CPU...")
+    top_k_likelihoods = sinkhorn_one_dim(top_k_likelihoods, n_iters=1)
+
+    # Create the filename in the same format as minimal_projection_generator.py
+    def clean_model_name_for_filename(name: str) -> str:
+        """Removes parameter counts and common suffixes from model names for cleaner filenames."""
+        import re
+        # Removes patterns like -8B, -1.5B, -4b, -125m etc.
+ cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + + student_clean_name = clean_model_name_for_filename(student_model_name.split("/")[-1]) + teacher_clean_name = clean_model_name_for_filename(teacher_model_name.split("/")[-1]) + + output_filename = f"projection_map_{student_clean_name}_to_{teacher_clean_name}_multitoken_top_{TOP_K}_double" + # if USE_RAW_TOKENS: + # output_filename += "_raw_tokens" + if ENABLE_SPECIAL_TOKEN_MAPPING: + output_filename += "_special" + output_filename += ".pt" + # if ENABLE_REVERSE_PASS: + # output_filename = output_filename.replace(".pt", "_bidirectional.pt") + output_path = os.path.join(args.output_dir, output_filename) + + # Save in same format as minimal_projection_generator.py + torch.save({ + "indices": top_k_indices, + "likelihoods": top_k_likelihoods, + "model_A_id": student_model_name, # source model (student) + "model_B_id": teacher_model_name, # target model (teacher) + }, output_path) + + print(f"Saved projection map to: {output_path}") + print(f"Format: indices shape {top_k_indices.shape}, likelihoods shape {top_k_likelihoods.shape}") + print(f"Compatible with minimal_projection_generator.py format") + print(f"Token processing mode: {'Raw tokens (convert_ids_to_tokens)' if USE_RAW_TOKENS else 'Decoded tokens (decode)'}") + if ENABLE_REVERSE_PASS: + print("File includes bidirectional mappings (teacher->student and student->teacher)") + if ENABLE_SPECIAL_TOKEN_MAPPING: + print(f"File includes special token mappings (similarity_threshold={SPECIAL_TOKEN_SIMILARITY_THRESHOLD}, top_k={SPECIAL_TOKEN_TOP_K})") + # exit() + + + + + # Test projection function compatibility (same as minimal_projection_generator.py) + print("\n--- Testing projection function 
compatibility ---") + + def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device): + """Projects token likelihoods from a source to a target vocabulary using a sparse map.""" + batch_size, seq_len, source_vocab_size = input_likelihoods.shape + if source_vocab_size != projection_map_indices.shape[0]: + raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})") + + top_k = projection_map_indices.shape[1] + input_likelihoods = input_likelihoods.to(device) + projection_map_indices = projection_map_indices.to(device) + projection_map_values = projection_map_values.to(device) + + crow_indices = torch.arange(0, (source_vocab_size + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = projection_map_indices.flatten() + values = projection_map_values.flatten() + + sparse_projection_matrix = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(source_vocab_size, target_vocab_size), device=device + ) + + reshaped_input = input_likelihoods.reshape(batch_size * seq_len, source_vocab_size) + projected_likelihoods_reshaped = torch.matmul(reshaped_input, sparse_projection_matrix) + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size) + + # Create a dummy likelihood tensor: [BATCH, SEQ, source_vocab_size] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_tensor = torch.randn(1, 4096, tokenizer_student_total_vocab_size, device=device, dtype=torch.bfloat16) + + # Transform this tensor using the projection map + projected_tensor = project_token_likelihoods( + dummy_tensor, + top_k_indices.to(device), + top_k_likelihoods.to(device), + tokenizer_teacher_total_vocab_size, + device + ) + print(f"Input tensor shape: {dummy_tensor.shape}") + print(f"Projected tensor shape: {projected_tensor.shape}") + print("Projection test successful - format is fully 
compatible!") + + # pdb.set_trace() + \ No newline at end of file diff --git a/nemo_rl/algorithms/x_token/reapply_exact_map.py b/nemo_rl/algorithms/x_token/reapply_exact_map.py new file mode 100644 index 0000000000..2537f8ec8a --- /dev/null +++ b/nemo_rl/algorithms/x_token/reapply_exact_map.py @@ -0,0 +1,233 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +import torch.nn as nn +from tokenalign import TokenAligner +import gc +from collections import defaultdict +from datasets import load_dataset, get_dataset_config_names +import random +import time +import numpy as np +import tqdm +import pdb +import difflib +import re +import argparse +import os + +def apply_canonicalization_if_enabled(token_str, use_canonicalization): + """Apply canonicalization to token string if enabled.""" + if use_canonicalization: + return TokenAligner._canonical_token(token_str) + return token_str + +def parse_arguments(): + """Parse command line arguments for the multi-token projection script.""" + parser = argparse.ArgumentParser( + description="Generate multi-token projection mappings between tokenizers", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Model selection arguments + parser.add_argument( + "--student-model", + type=str, + default="meta-llama/Llama-3.2-1B", + help="Student model name or path" + ) + parser.add_argument( + "--teacher-model", + type=str, + default="microsoft/phi-4", + help="Teacher model name or path" + ) + + # Boolean flags + parser.add_argument( + "--enable-scale-trick", + action="store_true", + default=True, + help="Enable scale trick (set last column likelihood to 0.2)" + ) + parser.add_argument( + "--disable-scale-trick", + action="store_false", + dest="enable_scale_trick", + help="Disable scale trick" + ) + parser.add_argument( + "--enable-reverse-pass", + action="store_true", + default=True, + help="Enable second pass: student tokens -> teacher tokens" + ) + parser.add_argument( + 
"--disable-reverse-pass", + action="store_false", + dest="enable_reverse_pass", + help="Disable reverse pass" + ) + parser.add_argument( + "--enable-exact-match", + action="store_true", + default=False, + help="Enable exact match enforcement for identical tokens" + ) + parser.add_argument( + "--use-raw-tokens", + action="store_true", + default=False, + help="Use convert_ids_to_tokens instead of decode, should be False" + ) + parser.add_argument( + "--enable-special-token-mapping", + action="store_true", + default=True, + help="Enable mapping of similar special tokens" + ) + parser.add_argument( + "--disable-special-token-mapping", + action="store_false", + dest="enable_special_token_mapping", + help="Disable special token mapping" + ) + parser.add_argument( + "--use-canonicalization", + action="store_true", + default=False, + help="Apply token canonicalization before processing to normalize different tokenizer representations (e.g., Ġ vs ▁ prefixes, Ċ vs \\n)" + ) + + # Numeric parameters + parser.add_argument( + "--tokens-to-cut", + type=int, + default=4, + help="Maximum number of tokens to consider for multi-token mappings" + ) + parser.add_argument( + "--top-k", + type=int, + default=32, + help="Number of top projections to keep for each token" + ) + parser.add_argument( + "--special-token-similarity-threshold", + type=float, + default=0.3, + help="Minimum similarity threshold for special token matching" + ) + parser.add_argument( + "--special-token-top-k", + type=int, + default=None, + help="Top K matches for each special token (defaults to --top-k value)" + ) + + # File paths + parser.add_argument( + "--initial-projection-path", + type=str, + default=None, + help="Path to initial projection map to load and extend" + ) + parser.add_argument( + "--output-dir", + type=str, + default="cross_tokenizer_data", + help="Output directory for saving projection maps" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + # Parse command line arguments + args 
= parse_arguments() + # Model names from arguments + teacher_model_name = args.teacher_model + student_model_name = args.student_model + USE_CANONICALIZATION = args.use_canonicalization + + + tokenizer_student = AutoTokenizer.from_pretrained(student_model_name) + tokenizer_teacher = AutoTokenizer.from_pretrained(teacher_model_name) + + tokenizer_student_total_vocab_size = len(tokenizer_student) + tokenizer_teacher_total_vocab_size = len(tokenizer_teacher) + model_A_config = AutoConfig.from_pretrained(student_model_name) + model_B_config = AutoConfig.from_pretrained(teacher_model_name) + + tokens_student = [apply_canonicalization_if_enabled(tokenizer_student.convert_ids_to_tokens([i])[0], USE_CANONICALIZATION) for i in range(tokenizer_student_total_vocab_size)] + tokens_teacher = [apply_canonicalization_if_enabled(tokenizer_teacher.convert_ids_to_tokens([j])[0], USE_CANONICALIZATION) for j in range(tokenizer_teacher_total_vocab_size)] + + map_teacher_token_to_idx = {token: j for j, token in enumerate(tokens_teacher)} + + # Find indices in student and teacher where the tokens are identical + match_indices_student = [] + match_indices_teacher = [] + for i, token_student in enumerate(tokens_student): + if token_student in map_teacher_token_to_idx: + j = map_teacher_token_to_idx[token_student] + match_indices_student.append(i) + match_indices_teacher.append(j) + + if match_indices_student: + print(f"Found {len(match_indices_student)} exact matches. Setting perfect 1-to-1 mappings.") + + # load intial projection map + initial_projection_path = args.initial_projection_path + if initial_projection_path is not None: + initial_projection_map = torch.load(initial_projection_path) + else: + initial_projection_map = None + + # go through token in projection map. 
For each token present in match_indices_student, set it's likelihoods and incices to 1.0 and the exact match teacher token + non_exact_map_tokens = list(range(len(initial_projection_map["likelihoods"]))) + all_student_token_ids = list(range(len(initial_projection_map["likelihoods"]))) + + show_remapping = 5 + if show_remapping > 0: + print(f"Showing remapping for the last {show_remapping} exact matches.") + else: + print(f"Not showing remapping.") + + for i, exact_token_student in enumerate(match_indices_student): + exact_token_teacher = match_indices_teacher[i] + + index_ = all_student_token_ids.index(exact_token_student) + likelihoods = initial_projection_map["likelihoods"][index_] + indices = initial_projection_map["indices"][index_] + + if len(match_indices_student) - i <= show_remapping: + print(f"prior to remapping: likelihoods {likelihoods} indices {indices}") + + topk = indices.shape[0] + + remapped_indices = torch.ones_like(indices) * -1 + remapped_likelihoods = torch.zeros_like(likelihoods) + + remapped_likelihoods[0] = 1.0 + remapped_indices[0] = exact_token_teacher + + # if exact_token_student == 5159: + # import pdb + # pdb.set_trace() + + initial_projection_map["likelihoods"][index_] = remapped_likelihoods + initial_projection_map["indices"][index_] = remapped_indices + + + if len(match_indices_student) - i <= show_remapping: + print(f'after remapping {tokens_student[exact_token_student]}:{exact_token_student} -> {tokens_teacher[exact_token_teacher]}:{exact_token_teacher}: likelihoods {initial_projection_map["likelihoods"][index_]} indices {initial_projection_map["indices"][index_]}') + non_exact_map_tokens.remove(index_) + + + # import pdb + # pdb.set_trace() + # print(f"non exact map tokens: {non_exact_map_tokens}") + # pdb.set_trace() + save_path = args.initial_projection_path.split(".")[0] + "_exact_map_remapped.pt" + torch.save(initial_projection_map, save_path) + print(f"Saved remapped projection map to: {save_path}") + print(f"remapped 
{len(match_indices_student)} tokens. Retained remaining {len(non_exact_map_tokens)} tokens as is.") \ No newline at end of file diff --git a/nemo_rl/algorithms/x_token/sort_and_cut_projection_matrix.py b/nemo_rl/algorithms/x_token/sort_and_cut_projection_matrix.py new file mode 100644 index 0000000000..31c0823037 --- /dev/null +++ b/nemo_rl/algorithms/x_token/sort_and_cut_projection_matrix.py @@ -0,0 +1,438 @@ +import torch +import os +import argparse +import tqdm +from transformers import AutoTokenizer, AutoConfig +import pdb; + +def sinkhorn_one_dim(A, n_iters=1): + """Apply Sinkhorn normalization to make each row sum to 1.""" + for _ in range(n_iters): + # A = A / (A.sum(dim=1, keepdim=True) + 1e-6) + row_sums = A.sum(dim=1, keepdim=True) + safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums) + A = A / safe_row_sums + return A + +def clean_model_name_for_filename(name: str) -> str: + """Removes parameter counts and common suffixes from model names for cleaner filenames.""" + import re + # Removes patterns like -8B, -1.5B, -4b, -125m etc. + cleaned_name = re.sub(r'-?[0-9\.]+[bBmB]', '', name, flags=re.IGNORECASE) + # Remove common suffixes + cleaned_name = cleaned_name.replace('-Base', '').replace('-it', '').replace('-Instruct', '') + # Clean up any leading/trailing hyphens that might result + cleaned_name = cleaned_name.strip('-_') + return cleaned_name + +def sort_and_cut_projection_matrix(input_path, output_path, new_top_k, preserve_last=False, verbose=True): + """ + Load a projection matrix, sort each row by weight values, and save with new top_k cutoff. 
+ + Args: + input_path: Path to input projection matrix file + output_path: Path to save the new projection matrix + new_top_k: New top_k value for cutoff + preserve_last: If True, always preserve the last column as the final element + verbose: Whether to print progress information + """ + if verbose: + print(f"Loading projection matrix from: {input_path}") + + # Load the projection matrix + projection_data = torch.load(input_path, map_location='cpu', weights_only=False) + + if not isinstance(projection_data, dict) or 'indices' not in projection_data or 'likelihoods' not in projection_data: + raise ValueError("Input file must contain a dictionary with 'indices' and 'likelihoods' keys") + + original_indices = projection_data['indices'] # Shape: [vocab_size, original_top_k] + original_likelihoods = projection_data['likelihoods'] # Shape: [vocab_size, original_top_k] + + vocab_size, original_top_k = original_indices.shape + + if verbose: + print(f"Original matrix shape: {original_indices.shape}") + print(f"Original top_k: {original_top_k}") + print(f"New top_k: {new_top_k}") + print(f"Preserve last column: {preserve_last}") + # pdb.set_trace() + + if new_top_k > original_top_k: + print(f"Warning: New top_k ({new_top_k}) is larger than original top_k ({original_top_k})") + print(f"Will pad with -1 indices and 0.0 likelihoods") + effective_top_k = original_top_k + else: + effective_top_k = new_top_k + + # Initialize new tensors + new_indices = torch.full((vocab_size, new_top_k), -1, dtype=original_indices.dtype) + new_likelihoods = torch.zeros((vocab_size, new_top_k), dtype=original_likelihoods.dtype) + + # Statistics tracking + rows_with_order_change = 0 + significant_components_count = [0] * min(new_top_k, 10) # Track up to 10 components + threshold_for_significance = 0.2 # Threshold for considering a component "significant" + # Track position of maximum element in original ordering + max_element_positions = {} # position -> count + # Track preserve_last statistics + 
rows_with_preserved_last = 0 + # Track specifically when max element is in the last column + rows_with_max_in_last_column = 0 + # Track position of maximum element in final sorted and trimmed matrix + final_max_element_positions = {} # position -> count + + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + # threshold_for_significance = 0.05 # Threshold for considering a component "significant" + + if verbose: + print("Sorting and cutting each row...") + + # Process each row (each source token) + last_element_trick_count = 0 + for row_idx in tqdm.tqdm(range(vocab_size), desc="Processing rows", disable=not verbose): + row_indices = original_indices[row_idx] # [original_top_k] + row_likelihoods = original_likelihoods[row_idx] # [original_top_k] + + # Filter out invalid indices (-1) and zero likelihoods + valid_mask = (row_indices != -1) & (row_likelihoods > 0) + + if valid_mask.any(): + valid_indices = row_indices[valid_mask] + valid_likelihoods = row_likelihoods[valid_mask] + + # Track position of maximum element in original ordering + max_pos = torch.argmax(valid_likelihoods).item() + if max_pos not in max_element_positions: + max_element_positions[max_pos] = 0 + max_element_positions[max_pos] += 1 + + # Check if max element is specifically in the last column + # Find the actual maximum value in the original row (including invalid entries) + original_max_pos = torch.argmax(row_likelihoods).item() + if original_max_pos == original_top_k - 1: + # Only count if the last position actually has valid data + last_index = row_indices[original_top_k - 1] + last_likelihood = row_likelihoods[original_top_k - 1] + if last_index != -1 and last_likelihood > 0: + rows_with_max_in_last_column += 1 + + if preserve_last and new_top_k >= 1: + # Handle preserve_last case + last_index = original_indices[row_idx, original_top_k - 1] + last_likelihood = original_likelihoods[row_idx, original_top_k - 1] + + if new_top_k == 1: + # Special case: 
only keep the last element + if last_index != -1 and last_likelihood > 0: + new_indices[row_idx, 0] = last_index + new_likelihoods[row_idx, 0] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant components + if last_likelihood >= threshold_for_significance: + significant_components_count[0] += 1 + else: + # General case: sort first (original_top_k-1) elements, then add last element + elements_to_sort = min(len(valid_likelihoods), original_top_k - 1) + if elements_to_sort > 0: + # Get elements excluding the last position in original matrix + sort_mask = torch.arange(len(valid_likelihoods)) < elements_to_sort + if sort_mask.any(): + sortable_indices = valid_indices[sort_mask] + sortable_likelihoods = valid_likelihoods[sort_mask] + + # Sort the non-last elements + sorted_likelihoods, sort_order = torch.sort(sortable_likelihoods, descending=True) + sorted_indices = sortable_indices[sort_order] + + # Check if order changed in the sortable portion + original_order = torch.arange(len(sortable_likelihoods)) + if not torch.equal(sort_order, original_order): + rows_with_order_change += 1 + + # Take top (new_top_k - 1) elements from sorted portion + num_from_sorted = min(len(sorted_indices), new_top_k - 1) + + new_indices[row_idx, :num_from_sorted] = sorted_indices[:num_from_sorted] + new_likelihoods[row_idx, :num_from_sorted] = sorted_likelihoods[:num_from_sorted] + + # Count significant components from sorted portion + for comp_idx in range(min(num_from_sorted, len(significant_components_count) - 1)): + if sorted_likelihoods[comp_idx] >= threshold_for_significance: + significant_components_count[comp_idx] += 1 + + # Always put the last element at the end (if valid) + + if last_index != -1 and last_likelihood > 0: + last_element_trick_count += 1 + new_indices[row_idx, new_top_k - 1] = last_index + new_likelihoods[row_idx, new_top_k - 1] = last_likelihood + rows_with_preserved_last += 1 + + # Count significant component for the preserved last element + 
if new_top_k - 1 < len(significant_components_count):
+                            if last_likelihood >= threshold_for_significance:
+                                significant_components_count[new_top_k - 1] += 1
+
+            else:
+                # Original logic: sort all elements normally
+                # Check if order changed by comparing original vs sorted order
+                original_order = torch.arange(len(valid_likelihoods))
+                sorted_likelihoods, sort_order = torch.sort(valid_likelihoods, descending=True)
+
+                # Check if the order changed (not just sorted, but actually different)
+                if not torch.equal(sort_order, original_order):
+                    rows_with_order_change += 1
+
+                sorted_indices = valid_indices[sort_order]
+
+                # Take top effective_top_k elements
+                num_to_take = min(len(sorted_indices), effective_top_k)
+
+                new_indices[row_idx, :num_to_take] = sorted_indices[:num_to_take]
+                new_likelihoods[row_idx, :num_to_take] = sorted_likelihoods[:num_to_take]
+
+                # Count significant components (components above threshold)
+                for comp_idx in range(min(num_to_take, len(significant_components_count))):
+                    if sorted_likelihoods[comp_idx] >= threshold_for_significance:
+                        significant_components_count[comp_idx] += 1
+
+    # If new_top_k > original_top_k, the tensors are already padded with -1 and 0.0
+
+    # Apply Sinkhorn normalization to the final matrix
+    if verbose:
+        print(f"last element trick count: {last_element_trick_count}")
+        print("Applying Sinkhorn normalization...")
+
+    # Apply normalization only to non-zero values to preserve sparsity structure
+    normalized_likelihoods = sinkhorn_one_dim(new_likelihoods.clone(), n_iters=1)
+
+    # Calculate final maximum element position statistics after sorting and normalization
+    if verbose:
+        print("Calculating final maximum element position statistics...")
+
+    for row_idx in range(vocab_size):
+        row_likelihoods = normalized_likelihoods[row_idx]
+        # Filter out zero likelihoods
+        valid_mask = row_likelihoods > 0
+        if valid_mask.any():
+            valid_likelihoods = 
row_likelihoods[valid_mask] + # Find position of maximum element in the final matrix + max_pos_in_valid = torch.argmax(valid_likelihoods).item() + # Convert back to original position in the row + valid_positions = torch.nonzero(valid_mask).squeeze(-1) + actual_max_pos = valid_positions[max_pos_in_valid].item() + + if actual_max_pos not in final_max_element_positions: + final_max_element_positions[actual_max_pos] = 0 + final_max_element_positions[actual_max_pos] += 1 + + # Create output dictionary with same format as input + output_data = { + 'indices': new_indices, + 'likelihoods': normalized_likelihoods, + } + + # Copy over any additional metadata + for key in projection_data: + if key not in ['indices', 'likelihoods']: + output_data[key] = projection_data[key] + + # Save the new projection matrix + torch.save(output_data, output_path) + + if verbose: + print(f"Saved sorted and cut projection matrix to: {output_path}") + print(f"New matrix shape: {new_indices.shape}") + + # Show basic statistics + non_zero_counts = (new_likelihoods > 0).sum(dim=1) + avg_non_zero = non_zero_counts.float().mean().item() + print(f"Average non-zero entries per row: {avg_non_zero:.2f}") + print(f"Rows with max entries ({new_top_k}): {(non_zero_counts == new_top_k).sum().item()}") + + # Show ordering statistics + print(f"\n=== Ordering Statistics ===") + print(f"Rows with changed order after sorting: {rows_with_order_change:,} / {vocab_size:,} ({100*rows_with_order_change/vocab_size:.1f}%)") + if preserve_last: + print(f"Rows with preserved last element: {rows_with_preserved_last:,} / {vocab_size:,} ({100*rows_with_preserved_last/vocab_size:.1f}%)") + + # Show last column maximum element statistics + print(f"\n=== Last Column Maximum Element Statistics ===") + total_rows_with_data = sum(max_element_positions.values()) + if total_rows_with_data > 0: + percentage_last_max = 100 * rows_with_max_in_last_column / total_rows_with_data + print(f"Rows with maximum element in LAST column: 
{rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({percentage_last_max:.1f}%)") + print(f"Rows with maximum element in NON-LAST columns: {total_rows_with_data - rows_with_max_in_last_column:,} / {total_rows_with_data:,} ({100 - percentage_last_max:.1f}%)") + else: + print(f"No valid data found to analyze last column statistics") + + # Show maximum element position distribution + print(f"\n=== Maximum Element Position Distribution (Original Ordering) ===") + total_rows_with_data = sum(max_element_positions.values()) + print(f"Total rows with valid data: {total_rows_with_data:,}") + + # Sort positions for ordered display + sorted_positions = sorted(max_element_positions.keys()) + for pos in sorted_positions[:20]: # Show up to first 20 positions + count = max_element_positions[pos] + percentage = 100 * count / total_rows_with_data if total_rows_with_data > 0 else 0 + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_positions) > 20: + remaining_count = sum(max_element_positions[pos] for pos in sorted_positions[20:]) + remaining_percentage = 100 * remaining_count / total_rows_with_data if total_rows_with_data > 0 else 0 + print(f"Rows with max element in positions 21+: {remaining_count:,} / {total_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show final maximum element position distribution (after sorting and normalization) + print(f"\n=== Maximum Element Position Distribution (Final Sorted & Normalized Matrix) ===") + total_final_rows_with_data = sum(final_max_element_positions.values()) + print(f"Total rows with valid data: {total_final_rows_with_data:,}") + + if total_final_rows_with_data > 0: + # Sort positions for ordered display + sorted_final_positions = sorted(final_max_element_positions.keys()) + for pos in 
sorted_final_positions[:min(new_top_k, 20)]: # Show up to new_top_k or 20 positions + count = final_max_element_positions[pos] + percentage = 100 * count / total_final_rows_with_data + ordinal = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"][pos] if pos < 10 else f"{pos+1}th" + print(f"Rows with max element in {ordinal} position: {count:,} / {total_final_rows_with_data:,} ({percentage:.1f}%)") + + if len(sorted_final_positions) > min(new_top_k, 20): + remaining_count = sum(final_max_element_positions[pos] for pos in sorted_final_positions[min(new_top_k, 20):]) + remaining_percentage = 100 * remaining_count / total_final_rows_with_data + print(f"Rows with max element in positions {min(new_top_k, 20)+1}+: {remaining_count:,} / {total_final_rows_with_data:,} ({remaining_percentage:.1f}%)") + + # Show significant components statistics + print(f"\n=== Significant Components Statistics (threshold >= {threshold_for_significance}) ===") + component_names = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th"] + for i, count in enumerate(significant_components_count): + percentage = 100 * count / vocab_size if vocab_size > 0 else 0 + print(f"Rows with significant {component_names[i]} component: {count:,} / {vocab_size:,} ({percentage:.1f}%)") + + # Additional analysis: distribution of likelihood values (after normalization) + all_likelihoods = normalized_likelihoods[normalized_likelihoods > 0] + if len(all_likelihoods) > 0: + print(f"\n=== Likelihood Distribution ===") + print(f"Total non-zero likelihoods: {len(all_likelihoods):,}") + print(f"Mean likelihood: {all_likelihoods.mean().item():.4f}") + print(f"Median likelihood: {all_likelihoods.median().item():.4f}") + print(f"Min likelihood: {all_likelihoods.min().item():.4f}") + print(f"Max likelihood: {all_likelihoods.max().item():.4f}") + + # Show percentiles - convert to float for quantile calculation + percentiles = [90, 95, 99] + all_likelihoods_float = all_likelihoods.float() 
+ for p in percentiles: + val = torch.quantile(all_likelihoods_float, p/100.0).item() + print(f"{p}th percentile: {val:.4f}") + + # Show how many rows have multiple significant components + print(f"\n=== Multi-Component Analysis ===") + rows_with_multiple_significant = 0 + for row_idx in range(vocab_size): + significant_in_row = (normalized_likelihoods[row_idx] >= threshold_for_significance).sum().item() + if significant_in_row >= 2: + rows_with_multiple_significant += 1 + + percentage_multi = 100 * rows_with_multiple_significant / vocab_size if vocab_size > 0 else 0 + print(f"Rows with 2+ significant components: {rows_with_multiple_significant:,} / {vocab_size:,} ({percentage_multi:.1f}%)") + + # Show normalization effect + print(f"\n=== Normalization Effect ===") + # Calculate row sums for ALL rows (including zero rows) + all_row_sums = normalized_likelihoods.sum(dim=1) + non_zero_rows = (normalized_likelihoods > 0).any(dim=1) + zero_rows = ~non_zero_rows + + print(f"Total rows: {vocab_size:,}") + print(f"Rows with non-zero entries: {non_zero_rows.sum().item():,}") + print(f"Rows with all zeros: {zero_rows.sum().item():,}") + + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + print(f"\nNon-zero rows statistics:") + print(f" Mean sum: {row_sums_nonzero.mean().item():.6f}") + print(f" Std sum: {row_sums_nonzero.std().item():.6f}") + print(f" Min sum: {row_sums_nonzero.min().item():.6f}") + print(f" Max sum: {row_sums_nonzero.max().item():.6f}") + + # Check how many rows don't sum to 1 (with different tolerance levels) + tolerances = [1e-6, 1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + perfect_rows = (torch.abs(row_sums_nonzero - 1.0) < tol).sum().item() + imperfect_rows = len(row_sums_nonzero) - perfect_rows + percentage_imperfect = 100 * imperfect_rows / len(row_sums_nonzero) + print(f" Rows NOT summing to 1.0 (tol={tol}): {imperfect_rows:,}/{len(row_sums_nonzero):,} ({percentage_imperfect:.2f}%)") + + # Show distribution of row sums 
that deviate from 1.0 + if non_zero_rows.any(): + row_sums_nonzero = all_row_sums[non_zero_rows] + deviations = torch.abs(row_sums_nonzero - 1.0) + significant_deviations = deviations > 1e-3 + + if significant_deviations.any(): + print(f"\nRows with significant deviations from 1.0 (>0.001): {significant_deviations.sum().item():,}") + worst_deviations = deviations[significant_deviations] + print(f" Mean deviation: {worst_deviations.mean().item():.6f}") + print(f" Max deviation: {worst_deviations.max().item():.6f}") + + # Show some examples of problematic rows + worst_indices = torch.topk(deviations, k=min(5, len(deviations)))[1] + print(f" Worst {min(5, len(worst_indices))} row examples:") + for i, idx in enumerate(worst_indices): + actual_row_idx = torch.nonzero(non_zero_rows)[idx].item() + sum_val = row_sums_nonzero[idx].item() + deviation = deviations[idx].item() + non_zero_count = (normalized_likelihoods[actual_row_idx] > 0).sum().item() + print(f" Row {actual_row_idx}: sum={sum_val:.6f}, deviation={deviation:.6f}, non_zeros={non_zero_count}") + else: + print(f"\nAll non-zero rows sum very close to 1.0 (deviation < 0.001)") + +def main(): + parser = argparse.ArgumentParser(description="Sort and cut projection matrix by top_k") + parser.add_argument("input_path", help="Path to input projection matrix file") + parser.add_argument("--top_k", type=int, required=True, help="New top_k value for cutoff") + parser.add_argument("--output_path", help="Output path (auto-generated if not specified)") + parser.add_argument("--preserve_last", action="store_true", help="Always preserve the last column as the final element") + parser.add_argument("--quiet", "-q", action="store_true", help="Suppress progress output") + # python sort_and_cut_projection_matrix.py /lustre/fsw/portfolios/nvr/projects/nvr_lpr_llm/users/pmolchanov/xtoken/models/runs/s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices/learned_projection_map_latest.pt --top_k 8 --output_path 
cross_tokenizer_data/projection_matrix_learned_llama_qwen_top8.pt --preserve_last + #s4_l1q4b_lr0_kl1_ce0_k1_emb_top10_3_learn_qa2_transformation_matrices + args = parser.parse_args() + + # Auto-generate output path if not specified + if args.output_path is None: + input_dir = os.path.dirname(args.input_path) + input_filename = os.path.basename(args.input_path) + + # Extract base name and extension + base_name, ext = os.path.splitext(input_filename) + + # Remove old top_k info if present + import re + base_name = re.sub(r'_top_\d+', '', base_name) + + # Add new top_k info and preserve_last flag + suffix = "_sorted" + if args.preserve_last: + suffix += "_preservelast" + output_filename = f"{base_name}_top_{args.top_k}{suffix}{ext}" + args.output_path = os.path.join(input_dir, output_filename) + + # Ensure output directory exists + os.makedirs(os.path.dirname(args.output_path), exist_ok=True) + + # Process the matrix + sort_and_cut_projection_matrix( + args.input_path, + args.output_path, + args.top_k, + preserve_last=args.preserve_last, + verbose=not args.quiet + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/nemo_rl/algorithms/x_token/tokenalign.py b/nemo_rl/algorithms/x_token/tokenalign.py new file mode 100644 index 0000000000..e874cc0dba --- /dev/null +++ b/nemo_rl/algorithms/x_token/tokenalign.py @@ -0,0 +1,4837 @@ +import copy +import json +import logging +import concurrent.futures +import torch.nn as nn +import tokenizers +import tokenizers.decoders +import tokenizers.normalizers +import tokenizers.pre_tokenizers +from tokenizers import Tokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM , AutoModel, AutoConfig +import numpy as np +from typing import Dict, List, Union, Callable, Optional +import os + +import torch + +try: + from numba import njit + _NUMBA_AVAILABLE = True +except ImportError: + _NUMBA_AVAILABLE = False + +##### define the format of projection matrix +##### go for 
dense as it is easier to train, and gradient is only computed for top_k +##### we will not have "A_to_B" and "B_to_A" to simplify, no bidirectional projection + +#### skip backprop if accuracy of alignment is <0.9 + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +logger = logging.getLogger(__name__) + + +if _NUMBA_AVAILABLE: + @njit(cache=True) + def _dp_core_numba(ids1, ids2, joined1, joined2, n1, n2, + exact_match_score, gap_penalty, comb_mul, max_comb_len): + """Numba-accelerated DP core for token alignment. + + Uses the same algorithm as align_tokens_with_combinations_numpy but + with integer ID comparisons instead of Python string operations. + + Trace codes: 0=start, 1=diag, 2=up, 3=left, + 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k + """ + INVALID = np.int64(-1) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) + + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 2 + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 3 + + for i in range(1, n1 + 1): + id_i = ids1[i - 1] + for j in range(1, n2 + 1): + id_j = ids2[j - 1] + + if id_i == id_j: + best = dp[i - 1, j - 1] + exact_match_score + else: + best = dp[i - 1, j - 1] - exact_match_score + best_m = np.int32(1) + + s = dp[i - 1, j] + gap_penalty + if s > best: + best = s + best_m = np.int32(2) + + s = dp[i, j - 1] + gap_penalty + if s > best: + best = s + best_m = np.int32(3) + + k_max_s2 = min(j, max_comb_len) + for k in range(2, k_max_s2 + 1): + jid = joined2[j, k] + if jid != INVALID and id_i == jid: + s = dp[i - 1, j - k] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(10 + k) + + k_max_s1 = min(i, max_comb_len) + for k in range(2, k_max_s1 + 1): + jid = joined1[i, k] + if jid != INVALID and id_j == jid: + s = dp[i - k, j - 1] + comb_mul * np.float32(k) + if s > best: + best = s + best_m = np.int32(20 + k) + + dp[i, j] = best + trace[i, j] = best_m + 
+ return dp, trace +else: + _dp_core_numba = None + + +class TokenAligner(nn.Module): + def __init__(self, max_comb_len=4, teacher_tokenizer_name=None, student_tokenizer_name=None, init_hf_tokenizers=True, track_rules=False, projection_matrix_multiplier=1.0, enable_scale_trick=None): + super().__init__() + self.teacher_tokenizer_name = teacher_tokenizer_name + self.student_tokenizer_name = student_tokenizer_name + self.track_rules = track_rules # Control whether to track alignment rules + self.projection_matrix_multiplier = projection_matrix_multiplier # Multiplier for projection matrix scaling + self.enable_scale_trick = enable_scale_trick # Override for SCALE_TRICK (if None, use default False) + + if init_hf_tokenizers: + self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_tokenizer_name) + self.student_tokenizer = AutoTokenizer.from_pretrained(student_tokenizer_name) + if self.teacher_tokenizer.pad_token is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + if self.student_tokenizer.pad_token is None: + self.student_tokenizer.pad_token = self.student_tokenizer.eos_token + else: + self.teacher_tokenizer = None + self.student_tokenizer = None + + self.forward_rules = set() # (seq1_tuple, seq2_tuple) + self.reverse_rules = set() # (seq2_tuple, seq1_tuple) + self.max_combination_len = max_comb_len + self.sparse_transformation_matrix = None + # Cached CSR for dense top-k projection (built from indices/values) to avoid scatter path + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + # Precomputed canonical ID maps (built by precompute_canonical_maps) + self._student_canon_map = None + self._teacher_canon_map = None + self._canon_id_to_str = None + + def precompute_canonical_maps(self): + """Build token_id → canonical_string lookup tables for both tokenizers. + + Call once at startup. After this, align_fast() can skip + convert_ids_to_tokens and _canonicalize_sequence entirely. 
+ """ + import time as _time + _t0 = _time.time() + + canon_str_to_id: dict[str, int] = {} + next_id = [0] + + def _get_canon_id(s: str) -> int: + cid = canon_str_to_id.get(s) + if cid is None: + cid = next_id[0] + canon_str_to_id[s] = cid + next_id[0] += 1 + return cid + + student_vocab_size = len(self.student_tokenizer) + teacher_vocab_size = len(self.teacher_tokenizer) + + student_map = np.zeros(student_vocab_size, dtype=np.int64) + for tid in range(student_vocab_size): + tok = self.student_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + student_map[tid] = _get_canon_id(canon) + + teacher_map = np.zeros(teacher_vocab_size, dtype=np.int64) + for tid in range(teacher_vocab_size): + tok = self.teacher_tokenizer.convert_ids_to_tokens(tid) + canon = self._canonical_token(tok) + teacher_map[tid] = _get_canon_id(canon) + + self._student_canon_map = student_map + self._teacher_canon_map = teacher_map + self._canon_id_to_str = {v: k for k, v in canon_str_to_id.items()} + + _t1 = _time.time() + print(f" [TokenAligner] Precomputed canonical maps in {_t1-_t0:.2f}s " + f"(student_vocab={student_vocab_size}, teacher_vocab={teacher_vocab_size}, " + f"unique_canonical={len(canon_str_to_id)})", flush=True) + + def align_fast(self, student_ids, teacher_ids, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + chunk_size=128, + post_process=True, + anchor_lengths=[3,], + ignore_leading_char_diff=False): + """Fast alignment using precomputed canonical ID maps. + + Skips convert_ids_to_tokens and _canonicalize_sequence by looking up + canonical strings directly from token IDs via precomputed numpy arrays. + Falls back to regular align() if precomputed maps are not available. 
+ """ + if self._student_canon_map is None: + return self.align(student_ids, teacher_ids, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + chunk_size=chunk_size, + post_process=post_process, + anchor_lengths=anchor_lengths, + ignore_leading_char_diff=ignore_leading_char_diff) + + if isinstance(student_ids, torch.Tensor): + student_ids = student_ids.cpu().numpy() + if isinstance(teacher_ids, torch.Tensor): + teacher_ids = teacher_ids.cpu().numpy() + + if student_ids.ndim == 1: + student_ids = student_ids[np.newaxis, :] + teacher_ids = teacher_ids[np.newaxis, :] + + import time as _time + _t_lookup_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + all_aligned_pairs = [] + for i in range(student_ids.shape[0]): + s_ids = student_ids[i] + t_ids = teacher_ids[i] + + _tl0 = _time.time() + s_canon_strs = [self._canon_id_to_str[self._student_canon_map[tid]] for tid in s_ids] + t_canon_strs = [self._canon_id_to_str[self._teacher_canon_map[tid]] for tid in t_ids] + _tl1 = _time.time() + _t_lookup_total += _tl1 - _tl0 + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(s_canon_strs, t_canon_strs, **align_kwargs) + _tl2 = _time.time() + _t_anchors_dp_total += _tl2 - _tl1 + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tl3 = _time.time() + _t_postprocess_total += _tl3 - _tl2 + 
+ mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, + ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value + in zip(aligned_pairs, mask) + ] + _tl4 = _time.time() + _t_mask_total += _tl4 - _tl3 + + all_aligned_pairs.append(aligned_pairs) + + n = student_ids.shape[0] + _t_total = _t_lookup_total + _t_anchors_dp_total + _t_postprocess_total + _t_mask_total + if _t_total > 0.5 or n > 1: + print(f" [align_fast timing] lookup={_t_lookup_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s, " + f"total={_t_total:.3f}s (n={n})", flush=True) + + return all_aligned_pairs + + def _convert_student_tokens_to_teacher_tokens(self, student_tokens: torch.Tensor) -> torch.Tensor: + device = student_tokens.device + dtype = student_tokens.dtype + if student_tokens.device != "cpu": + student_tokens = student_tokens.cpu() + + # Decode each sequence in the batch, not each individual token + text = [self.student_tokenizer.decode(sequence.tolist(), skip_special_tokens=True) for sequence in student_tokens] + teacher_tokens = [self.teacher_tokenizer.encode(text_single, max_length=student_tokens.shape[1], padding='max_length', truncation=True, return_tensors='pt').squeeze(0) for text_single in text] + + teacher_tokens = torch.stack(teacher_tokens).to(device).to(dtype) + return teacher_tokens + + def _load_logits_projection_map( + self, + folder_location: str = "cross_tokenizer_data", + file_path: str = None, + top_k: int = 100, + device: str = "cuda", + use_sparse_format: bool = False, + learnable: bool = False, + ): + """ + Load projection map for cross-tokenizer likelihood projection. + Always creates student→teacher mapping. 
+ + Args: + folder_location: Directory containing the projection files + file_path: Specific file path (overrides folder_location) + top_k: Number of top entries per row (only used for old format) + device: Device to load tensors on + use_sparse_format: If True, load sparse transformation matrix format (from multi-token mapping) + If False, load old dense indices/values format + learnable: If True, make the transformation matrix learnable + """ + self.learnable = learnable + if use_sparse_format: + # Load sparse transformation matrix format + if file_path is None: + file_path = f"{folder_location}/transformation_counts_via_multitoken.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Sparse transformation matrix file not found: {file_path}. Please generate it first.") + + # Load transformation counts dictionary + transformation_counts = torch.load(file_path, map_location='cpu', weights_only=False) + + # Get tokenizer vocab sizes + teacher_vocab_size = len(self.teacher_tokenizer) if self.teacher_tokenizer else 151669 # fallback + student_vocab_size = len(self.student_tokenizer) if self.student_tokenizer else 128256 # fallback + if 1: + # get vocab sizes from autoconfig + if "gemma" not in self.teacher_tokenizer_name.lower() and "qwen3.5" not in self.teacher_tokenizer_name.lower(): + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + else: + teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).text_config.vocab_size + if "gemma" not in self.student_tokenizer_name.lower() and "qwen3.5" not in self.student_tokenizer_name.lower(): + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + else: + student_vocab_size = AutoConfig.from_pretrained(self.student_tokenizer_name).text_config.vocab_size + # teacher_vocab_size = AutoConfig.from_pretrained(self.teacher_tokenizer_name).vocab_size + # student_vocab_size = 
AutoConfig.from_pretrained(self.student_tokenizer_name).vocab_size + + + # Debug vocab sizes + print(f"Teacher vocab size: {teacher_vocab_size}, Student vocab size: {student_vocab_size}") + + # Convert dictionary to sparse tensor + if transformation_counts: + + + indices = list(transformation_counts.keys()) + values = list(transformation_counts.values()) + + student_indices = [idx[0] for idx in indices] + teacher_indices = [idx[1] for idx in indices] + + # Always create student→teacher mapping: rows = student vocab, cols = teacher vocab + indices_tensor = torch.LongTensor([student_indices, teacher_indices]) + values_tensor = torch.FloatTensor(values)/self.projection_matrix_multiplier + matrix_shape = (student_vocab_size, teacher_vocab_size) + + print(f"Creating sparse matrix: student→teacher ({student_vocab_size} x {teacher_vocab_size})") + + sparse_transformation_matrix = torch.sparse_coo_tensor( + indices_tensor, + values_tensor, + (student_vocab_size, teacher_vocab_size), # student_vocab × teacher_vocab + device=device, + dtype=torch.float32 + ) + + # Optionally make the sparse matrix learnable (values only) + if learnable: + self.sparse_transformation_matrix = nn.Parameter( + sparse_transformation_matrix.coalesce(), requires_grad=True + ) + else: + # Register as buffer for non-learnable parameters (ensures proper device handling) + self.register_buffer('sparse_transformation_matrix', + sparse_transformation_matrix.coalesce(), + persistent=True) + + # Store a flag for downstream code + self.is_sparse_learnable = learnable + print(f"Loaded sparse transformation matrix with {len(transformation_counts)} entries") + else: + # Empty transformation matrix (student→teacher) + matrix_shape = (student_vocab_size, teacher_vocab_size) + + empty_sparse = torch.sparse_coo_tensor( + torch.zeros(2, 0, dtype=torch.long), + torch.zeros(0, dtype=torch.float32), + matrix_shape, + device=device, + ) + + if learnable: + self.sparse_transformation_matrix = nn.Parameter(empty_sparse, 
requires_grad=True) + else: + # Register as buffer for non-learnable parameters + self.register_buffer('sparse_transformation_matrix', empty_sparse, persistent=True) + + self.is_sparse_learnable = learnable + print("Warning: Empty transformation matrix loaded") + else: + # Load old dense indices/values format + if file_path is None: + file_path = f"{folder_location}/projection_map_Llama-3.1_to_Qwen3_bidirectional_top_10.pt" + + if not os.path.exists(file_path): + raise FileNotFoundError(f"Projection map file not found: {file_path}. Please generate it first.") + + projection_data = torch.load(file_path, map_location='cpu', weights_only=False) + # Always use B_to_A direction for student->teacher projection + # projection_data = projection_data["B_to_A"] + # projection_data = projection_data["A_to_B"] + + indices = projection_data["indices"] + likelihoods = projection_data["likelihoods"]/self.projection_matrix_multiplier + + # Register indices as buffer (always non-learnable) + self.register_buffer('likelihood_projection_indices', indices.to(device), persistent=True) + if learnable: + if 1: + likelihoods = (likelihoods+1e-10).log() + + # Use instance variable if set, otherwise use default (False) + # scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + + # if scale_trick_enabled: + # #trick with last column being multiplier - set to -4.0 + # likelihoods[:,-1] = likelihoods[:,-1]*0.0 - 4.0 + #lets introduce some noise to encourage training. will remove later. 
+ if 0: + likelihoods = likelihoods + torch.randn_like(likelihoods) * 1e-1 + likelihoods = likelihoods/2.0 + + self.likelihood_projection_matrix = nn.Parameter(likelihoods.to(device), requires_grad=True) + # print(self.likelihood_projection_matrix[0]) + # print(self.likelihood_projection_matrix[:,-1]) + # exit() + #add small gaussian noise to the projection matrix + #use log form + else: + # Register as buffer for non-learnable parameters + self.register_buffer('likelihood_projection_matrix', likelihoods.to(device), persistent=True) + + + print(f"Loaded dense projection map with shape {indices.shape}") + # Invalidate cached CSR; will rebuild on first use + self._dense_proj_csr = None + self._dense_proj_csr_device = None + + def create_reverse_projection_matrix(self, device="cuda"): + """ + Create a reverse (transposed) projection matrix for teacher→student projection. + + For sparse format: Transposes the sparse_transformation_matrix from [student_vocab, teacher_vocab] + to [teacher_vocab, student_vocab] + For dense format: Builds a reverse index mapping from teacher tokens to student tokens + + This enables projecting teacher logits into student vocabulary space. 
+ """ + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Transpose sparse matrix + print("Creating reverse projection matrix (sparse format): teacher→student") + sparse_matrix = self.sparse_transformation_matrix.coalesce() + indices = sparse_matrix.indices() + values = sparse_matrix.values() + + # Swap student and teacher indices (transpose) + transposed_indices = torch.stack([indices[1], indices[0]], dim=0) # Swap rows: [teacher, student] + teacher_vocab_size, student_vocab_size = sparse_matrix.shape[1], sparse_matrix.shape[0] + + reverse_sparse = torch.sparse_coo_tensor( + transposed_indices, + values, + (teacher_vocab_size, student_vocab_size), + device=device, + dtype=torch.float32 + ).coalesce() + + # Store as buffer or parameter based on learnability + if self.is_sparse_learnable: + self.reverse_sparse_transformation_matrix = nn.Parameter(reverse_sparse, requires_grad=True) + else: + self.register_buffer('reverse_sparse_transformation_matrix', reverse_sparse, persistent=True) + + print(f"Created reverse sparse matrix: teacher→student ({teacher_vocab_size} x {student_vocab_size})") + print(f"Reverse matrix has {len(values)} non-zero entries") + + elif hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + # Build reverse index for dense format + print("Creating reverse projection matrix (dense format): teacher→student") + + # Current: likelihood_projection_indices is [student_vocab, topk] + # We need to build: [teacher_vocab, variable_k] where variable_k depends on how many students map to each teacher token + + student_vocab_size = self.likelihood_projection_indices.shape[0] + topk = self.likelihood_projection_indices.shape[1] + + # Infer teacher vocab size from the max index + teacher_vocab_size = self.likelihood_projection_indices.max().item() + 1 + + # Build reverse mapping: for each teacher token, collect all (student_token, value) pairs + from collections 
import defaultdict + teacher_to_students = defaultdict(list) + + for student_idx in range(student_vocab_size): + for k in range(topk): + teacher_idx = self.likelihood_projection_indices[student_idx, k].item() + if hasattr(self, 'likelihood_projection_matrix'): + value = self.likelihood_projection_matrix[student_idx, k].item() + else: + value = 1.0 # Default value if no matrix + + # Check for valid entries: teacher_idx must be valid, and value must be finite (not -inf) + # If matrix is in log-space, valid log-probs are finite negative values + # Threshold at -20 to filter out padding values like -22.3197 + if teacher_idx >= 0 and value > -20.0: # Skip invalid or padding entries + teacher_to_students[teacher_idx].append((student_idx, value)) + + # Find max number of students mapping to any teacher token + raw_max_students = max([len(v) for v in teacher_to_students.values()]) if teacher_to_students else 1 + print(f"Max students mapping to any teacher token (before filtering): {raw_max_students}") + + # Limit to top-K students per teacher token to avoid explosion + # Keep only the top-K highest probability mappings per teacher + max_students_per_teacher = min(topk, raw_max_students) # Use same topk as forward direction + print(f"Limiting to top-{max_students_per_teacher} students per teacher token") + + # Sort each teacher's student list by value (descending) and keep only top-K + for teacher_idx in teacher_to_students: + student_list = teacher_to_students[teacher_idx] + # Sort by value (descending - higher log-prob = less negative) + student_list_sorted = sorted(student_list, key=lambda x: x[1], reverse=True) + teacher_to_students[teacher_idx] = student_list_sorted[:max_students_per_teacher] + + # Create dense reverse index [teacher_vocab, max_students_per_teacher] + # Use 0 instead of -1 for padding (valid index), with very negative values to nullify contribution + reverse_indices = torch.zeros((teacher_vocab_size, max_students_per_teacher), + dtype=torch.long, 
device=device)
+            # Initialize with very negative values (padding sentinel, similar to forward direction)
+            reverse_values = torch.full((teacher_vocab_size, max_students_per_teacher), -22.3197,
+                                        dtype=torch.float32, device=device)
+
+            for teacher_idx, student_list in teacher_to_students.items():
+                for k, (student_idx, value) in enumerate(student_list):
+                    reverse_indices[teacher_idx, k] = student_idx
+                    reverse_values[teacher_idx, k] = value
+
+            # Store as buffer or parameter
+            self.register_buffer('reverse_likelihood_projection_indices', reverse_indices, persistent=True)
+            if self.learnable:
+                self.reverse_likelihood_projection_matrix = nn.Parameter(reverse_values, requires_grad=True)
+            else:
+                self.register_buffer('reverse_likelihood_projection_matrix', reverse_values, persistent=True)
+
+            print(f"Created reverse dense projection: teacher→student ({teacher_vocab_size} x {max_students_per_teacher})")
+        else:
+            raise ValueError("No projection matrix loaded. Cannot create reverse projection.")
+
+    def update_transformation_matrix_from_checkpoint(self, transformation_data, device="cuda"):
+        """
+        Update the transformation matrix from loaded checkpoint data.
+ + Args: + transformation_data: Dictionary containing 'indices' and 'likelihoods' from checkpoint + device: Device to load the matrix on + + Returns: + bool: True if update was successful, False if skipped due to validation errors + """ + if transformation_data is None: + print("No transformation matrix data to load") + return False + + try: + indices = transformation_data["indices"].to(device) + likelihoods = transformation_data["likelihoods"].to(device) / self.projection_matrix_multiplier + + # Debug: print shapes and check compatibility + max_index = indices.max().item() if indices.numel() > 0 else -1 + min_index = indices.min().item() if indices.numel() > 0 else 0 + print(f"Checkpoint data - indices shape: {indices.shape}, likelihoods shape: {likelihoods.shape}") + print(f"Checkpoint data - indices range: [{min_index}, {max_index}]") + + if hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + current_max_index = self.likelihood_projection_indices.max().item() if self.likelihood_projection_indices.numel() > 0 else -1 + current_min_index = self.likelihood_projection_indices.min().item() if self.likelihood_projection_indices.numel() > 0 else 0 + print(f"Current - indices shape: {self.likelihood_projection_indices.shape}") + print(f"Current - indices range: [{current_min_index}, {current_max_index}]") + if hasattr(self, 'likelihood_projection_matrix') and self.likelihood_projection_matrix is not None: + current_matrix_shape = self.likelihood_projection_matrix.shape if hasattr(self.likelihood_projection_matrix, 'shape') else self.likelihood_projection_matrix.data.shape + print(f"Current - matrix shape: {current_matrix_shape}") + print(f"Current - matrix vocab size (dim 0): {current_matrix_shape[0]}") + + # Check for dimension compatibility before updating + if hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + if indices.shape != 
self.likelihood_projection_indices.shape: + print(f"WARNING: Indices shape mismatch! Checkpoint: {indices.shape} vs Current: {self.likelihood_projection_indices.shape}") + print("Skipping transformation matrix update due to shape mismatch") + return False + + if hasattr(self, 'likelihood_projection_matrix') and self.likelihood_projection_matrix is not None: + current_matrix_shape = self.likelihood_projection_matrix.shape if hasattr(self.likelihood_projection_matrix, 'shape') else self.likelihood_projection_matrix.data.shape + if likelihoods.shape != current_matrix_shape: + print(f"WARNING: Matrix shape mismatch! Checkpoint: {likelihoods.shape} vs Current: {current_matrix_shape}") + print("Skipping transformation matrix update due to shape mismatch") + return False + + # Additional validation: check if indices contain valid teacher vocabulary indices + # Since we project student→teacher, indices represent teacher vocabulary positions + # Get teacher vocab size from tokenizer or current matrix + max_teacher_vocab = None + + # Try to get teacher vocab size from tokenizer + if hasattr(self, 'teacher_tokenizer') and self.teacher_tokenizer is not None: + max_teacher_vocab = len(self.teacher_tokenizer.get_vocab()) + elif hasattr(self, 'teacher_tokenizer_name') and self.teacher_tokenizer_name is not None: + try: + from transformers import AutoTokenizer + temp_tokenizer = AutoTokenizer.from_pretrained(self.teacher_tokenizer_name, trust_remote_code=True) + max_teacher_vocab = len(temp_tokenizer.get_vocab()) + except Exception as e: + print(f"Warning: Could not load teacher tokenizer to check vocab size: {e}") + + # Fallback: infer from current target vocab size being used + if max_teacher_vocab is None and hasattr(self, 'likelihood_projection_matrix') and self.likelihood_projection_matrix is not None: + current_matrix_shape = self.likelihood_projection_matrix.shape if hasattr(self.likelihood_projection_matrix, 'shape') else self.likelihood_projection_matrix.data.shape + 
print(f"Warning: Using matrix shape to infer teacher vocab size: {current_matrix_shape}")
+                # This is likely wrong, but we'll use it as a fallback
+                max_teacher_vocab = current_matrix_shape[1] if len(current_matrix_shape) > 1 else current_matrix_shape[0]
+
+            if max_teacher_vocab is not None:
+                max_index = indices.max().item() if indices.numel() > 0 else -1
+                min_index = indices.min().item() if indices.numel() > 0 else 0
+                if max_index >= max_teacher_vocab or min_index < 0:
+                    print(f"ERROR: Index out of bounds! Indices range [{min_index}, {max_index}] but teacher vocab size is {max_teacher_vocab}")
+                    print("This indicates the transformation matrix was saved with a different teacher tokenizer")
+                    print(f"Current teacher: {getattr(self, 'teacher_tokenizer_name', 'unknown')}")
+                    print("Skipping transformation matrix update to prevent CUDA index errors")
+                    return False
+
+            if 1:
+                # The transformation matrix is checkpointed after softmax (as probabilities),
+                # so convert back to log-space here. This also fixes a bug where the very first
+                # matrix loaded correctly but subsequent checkpoint restarts did not.
+                likelihoods = (likelihoods + 1e-10).log()
+            # Check if we're using sparse format or dense format
+            if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None:
+                print("Warning: Cannot update sparse transformation matrix from dense checkpoint data")
+                print("Sparse matrix updates not yet implemented")
+                return False
+
+            # Update dense format matrices
+            if hasattr(self, 'likelihood_projection_indices') and hasattr(self, 'likelihood_projection_matrix'):
+                self.likelihood_projection_indices = indices
+
+                # Handle both learnable and non-learnable cases
+                if hasattr(self.likelihood_projection_matrix, 'data'):
+                    # It's a Parameter - update the data
+                    self.likelihood_projection_matrix.data = likelihoods
+                    print("Updated learnable transformation matrix from checkpoint")
+                else:
+                    # It's a regular tensor
+                    self.likelihood_projection_matrix = likelihoods
+                    
print("Updated fixed transformation matrix from checkpoint") + + # Invalidate cached CSR; will rebuild on first use + self._dense_proj_csr = None + self._dense_proj_csr_device = None + return True + else: + print("Warning: No existing transformation matrix structure found to update") + return False + + except Exception as e: + print(f"Error updating transformation matrix from checkpoint: {e}") + print("Continuing with original transformation matrix") + return False + + def get_transformation_matrix_for_checkpoint(self): + """ + Get the transformation matrix data for saving to checkpoint. + + Returns: + Dictionary containing 'indices' and 'likelihoods' for checkpoint saving, + or None if no transformation matrix is available. + """ + # Check if we have dense format transformation matrix + if hasattr(self, 'likelihood_projection_indices') and hasattr(self, 'likelihood_projection_matrix'): + if self.likelihood_projection_indices is not None and self.likelihood_projection_matrix is not None: + print(f"TokenAligner.get_transformation_matrix_for_checkpoint:") + print(f" Teacher: {getattr(self, 'teacher_tokenizer_name', 'unknown')}") + print(f" Student: {getattr(self, 'student_tokenizer_name', 'unknown')}") + print(f" Indices shape: {self.likelihood_projection_indices.shape}") + + # Get the matrix data (handle both Parameter and Tensor cases) + if hasattr(self.likelihood_projection_matrix, 'data'): + # It's a Parameter - get the data + matrix_data = self.likelihood_projection_matrix.data + print(f" Matrix type: Parameter, shape: {matrix_data.shape}") + else: + # It's a regular Tensor + matrix_data = self.likelihood_projection_matrix + print(f" Matrix type: Tensor, shape: {matrix_data.shape}") + + print(f" Matrix dtype: {matrix_data.dtype}") + print(f" Projection matrix multiplier: {self.projection_matrix_multiplier}") + + # Apply the projection matrix multiplier for saving (reverse the division done during loading) + likelihoods_for_save = matrix_data * 
self.projection_matrix_multiplier + + # Apply softmax to get probabilities for saving (reverse the log operation done during loading) + likelihoods_for_save = torch.softmax(likelihoods_for_save, dim=-1) + + print(f" Final likelihoods shape: {likelihoods_for_save.shape}, dtype: {likelihoods_for_save.dtype}") + + return { + "indices": self.likelihood_projection_indices.clone(), + "likelihoods": likelihoods_for_save.clone() + } + + # Check if we have sparse format transformation matrix + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + print("Warning: Saving sparse transformation matrix to checkpoint not yet implemented") + print("Returning None - sparse matrix will not be saved to checkpoint") + return None + + print("No transformation matrix available for checkpoint saving") + return None + + # @staticmethod + # def project_token_likelihoods(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_sparse_format=False, sparse_matrix=None, use_vectorized=True, projection_matrix_multiplier=1.0, gpu_optimized_scatter=True): + # """ + # Projects token likelihoods from a source to a target vocabulary using either dense or sparse projection. 
+ + # Args: + # input_likelihoods: Input likelihood tensor (batch_size, seq_len, source_vocab_size) + # projection_map_indices: Indices for dense format (source_vocab_size, top_k) + # projection_map_values: Values for dense format (source_vocab_size, top_k) + # target_vocab_size: Size of target vocabulary + # device: Device to run computation on + # use_sparse_format: If True, use sparse matrix projection + # sparse_matrix: Sparse transformation matrix (teacher_vocab_size, student_vocab_size) + # use_vectorized: If True (and use_sparse_format=False), use vectorized dense approach; + # If False, use sparse CSR matrix approach (only for dense format) + # gpu_optimized_scatter: If True, uses a more GPU-friendly scatter operation for dense projection. + # """ + # if use_sparse_format: + # if sparse_matrix is None: + # raise ValueError("sparse_matrix must be provided when use_sparse_format=True") + # return TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*projection_matrix_multiplier, device) + # else: + # return TokenAligner.project_token_likelihoods_dense(input_likelihoods, projection_map_indices, projection_map_values*projection_matrix_multiplier, target_vocab_size, device, use_vectorized, gpu_optimized_scatter=gpu_optimized_scatter, enable_scale_trick=None) + + def project_token_likelihoods_ultra_fast(self, input_likelihoods, sparse_matrix=None, target_vocab_reduced_indices=None): + """ + Ultra-fast projection optimized for sparse matrices and reduced vocabularies. 
+ + Args: + input_likelihoods: Input probabilities (B, S, V_student) + sparse_matrix: Sparse transformation matrix + target_vocab_reduced_indices: If provided, only project to these teacher vocab positions + """ + if sparse_matrix is None: + sparse_matrix = self.sparse_transformation_matrix + + if sparse_matrix is None: + raise ValueError("No sparse matrix available for ultra-fast projection") + + # Cache CSR conversion for repeated use + if not hasattr(self, '_sparse_csr_cache') or self._sparse_csr_cache.get('matrix_id') != id(sparse_matrix): + sparse_csr = sparse_matrix.to_sparse_csr() + self._sparse_csr_cache = { + 'matrix_id': id(sparse_matrix), + 'csr_matrix': sparse_csr + } + else: + sparse_csr = self._sparse_csr_cache['csr_matrix'] + + # Ultra-fast sparse matmul with shape optimization + bsz, seqlen, vs = input_likelihoods.shape + x2d = input_likelihoods.reshape(bsz * seqlen, vs) + + # Use optimized sparse matmul (often faster than dense) + out2d = torch.sparse.mm(sparse_csr.t(), x2d.t()).t() + + result = out2d.reshape(bsz, seqlen, -1) + + # If target vocab is reduced, slice early + if target_vocab_reduced_indices is not None: + result = result[:, :, target_vocab_reduced_indices] + + return result + + def project_token_likelihoods_instance(self, input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_sparse_format=False, sparse_matrix=None, use_vectorized=True, gpu_optimized_scatter=True, global_top_indices=None): + """ + Instance method wrapper for project_token_likelihoods that can access instance variables. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). 
+ """ + if use_sparse_format: + if sparse_matrix is None: + raise ValueError("sparse_matrix must be provided when use_sparse_format=True") + + if global_top_indices is not None: + # For sparse format with global_top_indices, project to full vocab then slice + full_projection = TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + return full_projection[:, :, global_top_indices] + else: + return TokenAligner.project_token_likelihoods_sparse(input_likelihoods, sparse_matrix*self.projection_matrix_multiplier, device) + else: + # If projection map is learnable, fall back to dense scatter path to preserve gradients + if getattr(projection_map_values, "requires_grad", False): + scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False + return TokenAligner.project_token_likelihoods_dense( + input_likelihoods, + projection_map_indices, + projection_map_values * self.projection_matrix_multiplier, + target_vocab_size, + device, + use_vectorized=True, + gpu_optimized_scatter=gpu_optimized_scatter, + enable_scale_trick=scale_trick_enabled, + global_top_indices=global_top_indices, + ) + + # Otherwise, use stateless CSR matmul (no caching) for memory efficiency + vs = projection_map_indices.shape[0] + top_k = projection_map_indices.shape[1] + # Ensure device/dtype for indices/values + idx = projection_map_indices.to(device) + val = (projection_map_values * self.projection_matrix_multiplier).to(device) + if val.dtype != input_likelihoods.dtype: + val = val.to(input_likelihoods.dtype) + # Build CSR once per call outside autograd to keep checkpoint recomputation identical + with torch.no_grad(): + crow_indices = torch.arange(0, (vs + 1) * top_k, top_k, device=device, dtype=torch.long) + col_indices = idx.reshape(-1) + values = val.reshape(-1) + proj_csr = torch.sparse_csr_tensor( + crow_indices, col_indices, values, size=(vs, target_vocab_size), device=device + ) + # Matmul: 
[B, S, Vs] -> [B*S, Vs] @ [Vs, Vt] -> [B*S, Vt] -> [B, S, Vt] + bsz, seqlen, vs_in = input_likelihoods.shape + if vs_in != vs: + # In case logits have extra vocab tail, slice to match + x = input_likelihoods[:, :, :vs] + else: + x = input_likelihoods + x2d = x.reshape(bsz * seqlen, vs) + out2d = torch.matmul(x2d.to(torch.float32), proj_csr.to(torch.float32)) + out = out2d.reshape(bsz, seqlen, target_vocab_size).to(input_likelihoods.dtype) + return out + + @staticmethod + def project_token_likelihoods_dense(input_likelihoods, projection_map_indices, projection_map_values, target_vocab_size, device, use_vectorized=True, gpu_optimized_scatter=True, enable_scale_trick=None, global_top_indices=None): + """ + Projects token likelihoods from a source to a target vocabulary using dense indices/values format. + + Args: + global_top_indices: Optional tensor of shape (K,) containing indices of target tokens to project to. + If provided, only projects to these K tokens instead of full target_vocab_size. + Results in (batch, seq, K) output instead of (batch, seq, target_vocab_size). + MAJOR SPEEDUP: Reduces both memory and compute significantly. 
+        """
+        batch_size, seq_len, source_vocab_size = input_likelihoods.shape
+        if abs(source_vocab_size - projection_map_indices.shape[0]) > 1000:
+            raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches projection map size ({projection_map_indices.shape[0]})")
+
+        top_k = projection_map_indices.shape[1]
+        input_likelihoods = input_likelihoods.to(device)
+        if projection_map_indices.device != device:
+            projection_map_indices = projection_map_indices.to(device)
+        if projection_map_values.device != device:
+            projection_map_values = projection_map_values.to(device)
+        # Match the values dtype to the input dtype as well
+        if projection_map_values.dtype != input_likelihoods.dtype:
+            projection_map_values = projection_map_values.to(input_likelihoods.dtype)
+
+        if use_vectorized:
+            # Solution 1: Efficient dense implementation using vectorized operations for small top_k
+            source_vocab_size_fixed = projection_map_indices.shape[0]
+            input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed]
+
+            # OPTIMIZATION: Use reduced vocabulary if global_top_indices provided
+            if global_top_indices is not None:
+                k_indices = len(global_top_indices)
+                global_top_indices = global_top_indices.to(device)
+
+                # Create mapping from full target indices to reduced indices [0, 1, 2, ..., k-1]
+                full_to_reduced_map = torch.full((target_vocab_size,), -1, device=device, dtype=torch.long)
+                full_to_reduced_map[global_top_indices] = torch.arange(k_indices, device=device)
+
+                # Initialize smaller output tensor - MAJOR MEMORY SAVINGS
+                projected_likelihoods = torch.zeros(batch_size, seq_len, k_indices,
+                                                    device=device, dtype=input_likelihoods.dtype)
+                effective_vocab_size = k_indices
+
+                # Filter projection matrices to only include mappings to global_top_indices
+                # This will be used in the scatter operations below
+                use_reduced_projection = True
+            else:
+                # Initialize full output tensor
+                projected_likelihoods = 
torch.zeros(batch_size, seq_len, target_vocab_size, + device=device, dtype=input_likelihoods.dtype) + effective_vocab_size = target_vocab_size + use_reduced_projection = False + + # Optimized chunked processing with multiple speedup techniques + # Use larger chunks for better amortization of fixed costs + max_memory_mb = 200 # Increased for better performance + # max_memory_mb = 500 # Increased for better performance + elements_per_chunk = max_memory_mb * 1024 * 1024 // 4 # 4 bytes per float32 + chunk_size = max(512, min(source_vocab_size_fixed, elements_per_chunk // (batch_size * seq_len))) + + + use_masking = False + # Process vocabulary in optimized chunks + for chunk_start in range(0, source_vocab_size_fixed, chunk_size): + chunk_end = min(chunk_start + chunk_size, source_vocab_size_fixed) + chunk_len = chunk_end - chunk_start + + + input_chunk = input_likelihoods_fixed[:, :, chunk_start:chunk_end] # (B, S, chunk_len) + indices_chunk = projection_map_indices[chunk_start:chunk_end, :] # (chunk_len, top_k) + values_chunk = projection_map_values[chunk_start:chunk_end, :] # (chunk_len, top_k) + + # Extract input chunk once per chunk (not per k) - major speedup + # Determine effective top_k (exclude last column if scale trick is enabled) + scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False + effective_top_k = top_k - 1 if scale_trick_enabled else top_k + # effective_top_k = 1 + + if gpu_optimized_scatter: + if use_masking: + # Process one k at a time to reduce peak memory usage + for k in range(effective_top_k): + values_k = values_chunk[:, k] + valid_mask_k = values_k > 1e-4 + if not valid_mask_k.any(): + continue + + source_indices_k = torch.nonzero(valid_mask_k, as_tuple=True)[0] + + input_subset_k = input_chunk[:, :, source_indices_k] + values_subset_k = values_k[source_indices_k] + + indices_k = indices_chunk[:, k] + target_indices_subset_k = indices_k[source_indices_k] + + weighted_inputs_k = input_subset_k * 
values_subset_k.view(1, 1, -1) + expanded_target_indices_k = target_indices_subset_k.view(1, 1, -1).expand(batch_size, seq_len, -1) + + projected_likelihoods.scatter_add_(2, expanded_target_indices_k, weighted_inputs_k) + else: + # Compact, un-masked implementation + # Process only effective columns without creating intermediate tensors + input_expanded = input_chunk.unsqueeze(-1) # (B, S, chunk_len, 1) + + for k in range(effective_top_k): + values_k = values_chunk[:, k:k+1] # (chunk_len, 1) - view, no copy + indices_k = indices_chunk[:, k] # (chunk_len,) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + # Map full indices to reduced indices and filter out invalid ones + reduced_indices_k = full_to_reduced_map[indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 # Only keep indices in global_top_indices + + if not valid_mask.any(): + continue # Skip if no valid indices in this chunk + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + values_filtered = values_k.squeeze(-1)[valid_indices] # (valid_count,) + input_filtered = input_chunk[:, :, valid_indices] # (B, S, valid_count) + + weighted_k = input_filtered * values_filtered.unsqueeze(0).unsqueeze(0) + indices_expanded = reduced_indices_filtered.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Standard full projection + weighted_k = input_expanded * values_k.unsqueeze(0).unsqueeze(0) # (B, S, chunk_len, 1) + weighted_k = weighted_k.squeeze(-1) # (B, S, chunk_len) + + indices_expanded = indices_k.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1) + projected_likelihoods.scatter_add_(2, indices_expanded, weighted_k) + else: + # Original implementation with a loop over top_k + if True: # For small top_k, process all k together + # 
Broadcast input: (B, S, chunk_len, 1) * (1, 1, chunk_len, top_k) -> (B, S, chunk_len, top_k) + weighted_inputs = input_chunk.unsqueeze(-1) * values_chunk.unsqueeze(0).unsqueeze(0) + + # Process all k simultaneously using advanced indexing + for k in range(effective_top_k): + target_indices_k = indices_chunk[:, k] # (chunk_len,) + weighted_k = weighted_inputs[:, :, :, k] # (B, S, chunk_len) + + if use_reduced_projection: + # OPTIMIZATION: Only project to indices in global_top_indices + reduced_indices_k = full_to_reduced_map[target_indices_k] # (chunk_len,) + valid_mask = reduced_indices_k != -1 + + if not valid_mask.any(): + continue # Skip if no valid indices + + # Filter to only valid entries - MAJOR COMPUTE SAVINGS + valid_indices = torch.nonzero(valid_mask, as_tuple=True)[0] + reduced_indices_filtered = reduced_indices_k[valid_indices] + weighted_filtered = weighted_k[:, :, valid_indices] # (B, S, valid_count) + + target_expanded = reduced_indices_filtered.view(1, 1, -1).expand(batch_size, seq_len, len(valid_indices)) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_filtered) + else: + # Use optimized scatter with pre-expanded indices (avoid .expand() in loop) + target_expanded = target_indices_k.view(1, 1, -1).expand(batch_size, seq_len, chunk_len) + projected_likelihoods.scatter_add_(2, target_expanded, weighted_k) + + # else: # For larger top_k, use optimized sequential processing + # for k in range(top_k): + # target_indices_k = indices_chunk[:, k] # (chunk_len,) + # target_values_k = values_chunk[:, k] # (chunk_len,) + + # # Skip projections marked with -1 + # valid_mask = target_values_k > -0.00001 + # if not valid_mask.any(): + # continue + + # # Only process valid projections + # valid_target_indices = target_indices_k[valid_mask] + # valid_target_values = target_values_k[valid_mask] + # valid_input = input_chunk[valid_mask] + + # weighted_input = valid_input * valid_target_values.view(-1, 1, 1) + + # # Direct scatter (simpler and often 
faster than index caching)
+            #         target_expanded = valid_target_indices.view(1, 1, -1).expand(batch_size, seq_len, valid_target_indices.size(0))
+            #         projected_likelihoods.scatter_add_(2, target_expanded, weighted_input)
+
+            return projected_likelihoods
+        else:
+            # Solution 2: Sparse matrix approach (original implementation)
+            source_vocab_size_fixed = projection_map_indices.shape[0]
+
+            # Create sparse CSR matrix
+            crow_indices = torch.arange(0, (source_vocab_size_fixed + 1) * top_k, top_k, device=device, dtype=torch.long)
+            col_indices = projection_map_indices.flatten()
+            values = projection_map_values.flatten()
+
+            sparse_projection_matrix = torch.sparse_csr_tensor(
+                crow_indices, col_indices, values, size=(source_vocab_size_fixed, target_vocab_size), device=device
+            )
+
+            # Apply sparse matrix multiplication
+            input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed]
+            # Reshape using the sliced vocab size (reshaping with the raw source_vocab_size
+            # fails whenever it differs from the projection map's vocab size)
+            reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size_fixed)
+
+            projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_projection_matrix.to(torch.float32))
+
+            return projected_likelihoods_reshaped.reshape(batch_size, seq_len, target_vocab_size).to(input_likelihoods.dtype)
+
+    @staticmethod
+    def project_token_likelihoods_sparse(input_likelihoods, sparse_matrix, device):
+        """Projects token likelihoods using a sparse transformation matrix."""
+        batch_size, seq_len, source_vocab_size = input_likelihoods.shape
+
+        # Get dimensions from sparse matrix
+        matrix_input_size, matrix_output_size = sparse_matrix.shape
+
+        if abs(source_vocab_size - matrix_input_size) > 1000:
+            raise ValueError(f"Source vocab size of input ({source_vocab_size}) mismatches sparse matrix input size ({matrix_input_size})")
+
+        # Move to correct device and dtype
+        # input_likelihoods = input_likelihoods.to(device)
+        # sparse_matrix = sparse_matrix.to(device)
+
+        # Adjust input size to match matrix dimensions
+        # next 2 lines required when we used 
vocab length from tokenizer, now we use the size of logits + # source_vocab_size_fixed = min(source_vocab_size, matrix_input_size) + # input_likelihoods_fixed = input_likelihoods[:, :, :source_vocab_size_fixed] + input_likelihoods_fixed = input_likelihoods + + # Reshape for matrix multiplication + reshaped_input = input_likelihoods_fixed.reshape(batch_size * seq_len, source_vocab_size) + + # Project using sparse matrix multiplication + projected_likelihoods_reshaped = torch.matmul(reshaped_input.to(torch.float32), sparse_matrix.to(torch.float32)) + + # Reshape back to original format + return projected_likelihoods_reshaped.reshape(batch_size, seq_len, matrix_output_size).to(input_likelihoods.dtype) + + def align(self, student_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + teacher_seq: Union[List[str], List[List[str]], List[int], List[List[int]]], + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=False, + chunk_size=128, + post_process=True, + convert_ids_to_tokens=True, + anchor_lengths=[3,], + track_rules=None, + _debug_timing=False): + """ + Aligns two sequences of tokens (or batches of sequences), identifies + translation rules, and updates the internal rule set. 
+ """ + import time as _time + + should_track_rules = track_rules if track_rules is not None else self.track_rules + + seq1 = student_seq + seq2 = teacher_seq + + original_seq1_ids = None + original_seq2_ids = None + + _t_convert = 0.0 + if isinstance(seq1, torch.Tensor): + original_seq1_ids = seq1.cpu().tolist() + original_seq2_ids = seq2.cpu().tolist() + + seq1 = seq1.cpu().tolist() + seq2 = seq2.cpu().tolist() + if convert_ids_to_tokens: + _tc0 = _time.time() + seq1 = [self.student_tokenizer.convert_ids_to_tokens(seq1_single) for seq1_single in seq1] + seq2 = [self.teacher_tokenizer.convert_ids_to_tokens(seq2_single) for seq2_single in seq2] + _t_convert = _time.time() - _tc0 + + is_batched = isinstance(seq1, list) and len(seq1) > 0 and isinstance(seq1[0], list) + + _t_canon_total = 0.0 + _t_anchors_dp_total = 0.0 + _t_postprocess_total = 0.0 + _t_mask_total = 0.0 + + if is_batched: + if not (isinstance(seq2, list) and len(seq2) == len(seq1) and (len(seq2) == 0 or isinstance(seq2[0], list))): + raise ValueError("For batched input, seq1 and seq2 must be lists of lists with the same length.") + + all_aligned_pairs = [] + for i, (s1, s2) in enumerate(zip(seq1, seq2)): + s1_ids = original_seq1_ids[i] if original_seq1_ids else None + s2_ids = original_seq2_ids[i] if original_seq2_ids else None + aligned_pairs, timings = self._align_single(s1, s2, exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, s1_ids, s2_ids, should_track_rules, _return_timings=True) + all_aligned_pairs.append(aligned_pairs) + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + else: + s1_ids = original_seq1_ids[0] if original_seq1_ids else None + s2_ids = original_seq2_ids[0] if original_seq2_ids else None + aligned_pairs, timings = self._align_single(seq1, seq2, 
exact_match_score, combination_score_multiplier, gap_penalty, ignore_leading_char_diff, chunk_size, post_process, anchor_lengths, s1_ids, s2_ids, should_track_rules, _return_timings=True) + all_aligned_pairs = [aligned_pairs] + _t_canon_total += timings.get("canon", 0) + _t_anchors_dp_total += timings.get("anchors_dp", 0) + _t_postprocess_total += timings.get("postprocess", 0) + _t_mask_total += timings.get("mask", 0) + + if _debug_timing: + n = len(all_aligned_pairs) + print(f" [align timing] convert_ids={_t_convert:.3f}s, " + f"canonicalize={_t_canon_total:.3f}s, " + f"anchors+DP={_t_anchors_dp_total:.3f}s, " + f"postprocess={_t_postprocess_total:.3f}s, " + f"mask={_t_mask_total:.3f}s " + f"(n={n})", flush=True) + + return all_aligned_pairs + + def _align_single(self, seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + ignore_leading_char_diff=True, + chunk_size=0, + post_process=True, + anchor_lengths=None, + seq1_token_ids=None, + seq2_token_ids=None, + track_rules=None, + _return_timings=False): + """ + Aligns two sequences of tokens, identifies translation rules, and updates + the internal rule set. 
+ """ + import time as _time + + _tc0 = _time.time() + seq1_canon = TokenAligner._canonicalize_sequence(seq1) + seq2_canon = TokenAligner._canonicalize_sequence(seq2) + _tc1 = _time.time() + + align_kwargs = { + 'exact_match_score': exact_match_score, + 'combination_score_multiplier': combination_score_multiplier, + 'gap_penalty': gap_penalty, + 'max_combination_len': self.max_combination_len, + 'ignore_leading_char_diff': False, + 'chunk_size': chunk_size, + 'anchor_lengths': anchor_lengths, + } + + aligned_pairs, _ = self._align_with_anchors(seq1_canon, seq2_canon, **align_kwargs) + _tc2 = _time.time() + + if post_process: + aligned_pairs = self.post_process_alignment_optimized( + aligned_pairs, + ignore_leading_char_diff=ignore_leading_char_diff, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=self.max_combination_len + ) + _tc3 = _time.time() + + mask = self.get_alignment_mask(aligned_pairs, use_canonicalization=True, ignore_leading_char_diff=ignore_leading_char_diff) + aligned_pairs = [ + (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, mask_value) + for (s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end), mask_value in zip(aligned_pairs, mask) + ] + _tc4 = _time.time() + + if track_rules: + self._update_rules(aligned_pairs, seq1_token_ids, seq2_token_ids, seq1, seq2) + + timings = { + "canon": _tc1 - _tc0, + "anchors_dp": _tc2 - _tc1, + "postprocess": _tc3 - _tc2, + "mask": _tc4 - _tc3, + } + + if _return_timings: + return aligned_pairs, timings + return aligned_pairs + + + def compute_accuracy(self, aligned_pairs, ignore_student_ids=None, ignore_teacher_ids=None): + """ + Compute alignment accuracy from aligned pairs with support for batched input. 
+ + Args: + aligned_pairs: Either a single list of aligned pairs or a list of lists (batched) + ignore_student_ids: Set of student token IDs to ignore when computing accuracy + ignore_teacher_ids: Set of teacher token IDs to ignore when computing accuracy + + Returns: + For single input: Single accuracy value (float) + For batched input: List of accuracy values (List[float]) + """ + if ignore_student_ids is None: + ignore_student_ids = set() + if ignore_teacher_ids is None: + ignore_teacher_ids = set() + + def is_not_ignored(token_or_tokens, ignore_ids): + """Check if token(s) should not be ignored in accuracy computation.""" + if isinstance(token_or_tokens, (list, tuple)): + return all(tok not in ignore_ids for tok in token_or_tokens) + else: + return token_or_tokens not in ignore_ids + + def compute_single_accuracy(single_aligned_pairs): + """Compute accuracy for a single sequence's aligned pairs.""" + if not single_aligned_pairs: + return 0.0 + + mask_values = [ + pair[6] # is_correct mask value + for pair in single_aligned_pairs + if is_not_ignored(pair[0], ignore_student_ids) and is_not_ignored(pair[1], ignore_teacher_ids) + ] + + if not mask_values: + return 0.0 + + return sum(mask_values) / float(len(mask_values)) + + # Check if input is batched (list of lists of aligned pairs) + # First check if it's a list and has elements + if isinstance(aligned_pairs, list) and len(aligned_pairs) > 0: + # Check if the first element is itself a list of tuples (indicating batched input) + if (isinstance(aligned_pairs[0], list) and len(aligned_pairs[0]) > 0 and + (isinstance(aligned_pairs[0][0], tuple) or isinstance(aligned_pairs[0][0], list)) and len(aligned_pairs[0][0]) == 7): + # Batched input: compute accuracy for each batch item + return [compute_single_accuracy(batch_pairs) for batch_pairs in aligned_pairs] + elif isinstance(aligned_pairs[0], tuple) and len(aligned_pairs[0]) == 7: + # Single sequence input + return compute_single_accuracy(aligned_pairs) + + # Empty or 
invalid input
        return 0.0

    def _align_with_anchors(self, seq1, seq2, anchor_lengths=[3,], **kwargs):
        """
        Optimized alignment using unique 1-to-1 matches as anchors.
        """
        if anchor_lengths is None:
            anchor_lengths = [3, 2]  # Default: check 3-token, then 2-token sequences

        # An explicitly empty anchor_lengths disables anchor optimization entirely
        if not anchor_lengths:
            return self._perform_dp_alignment(seq1, seq2, **kwargs)

        # 1. Find high-confidence anchor points using unique token matches.
        s1_counts = {}
        for i, t in enumerate(seq1):
            if t not in s1_counts:
                s1_counts[t] = []
            s1_counts[t].append(i)

        s2_counts = {}
        for i, t in enumerate(seq2):
            if t not in s2_counts:
                s2_counts[t] = []
            s2_counts[t].append(i)

        # Collect candidate anchors from all requested lengths, then choose the best set
        all_potential_anchors = []

        for anchor_len in anchor_lengths:
            anchors_for_this_len = []

            if anchor_len == 1:
                # Single-token anchors: tokens that occur exactly once in both sequences
                common_tokens = s1_counts.keys() & s2_counts.keys()
                for token in common_tokens:
                    if len(s1_counts[token]) == 1 and len(s2_counts[token]) == 1:
                        i = s1_counts[token][0]
                        j = s2_counts[token][0]
                        anchors_for_this_len.append((i, j, anchor_len))
            else:
                # Multi-token anchors: n-grams that occur exactly once in both sequences
                s1_ngram_counts = {}
                for i in range(len(seq1) - anchor_len + 1):
                    ngram = tuple(seq1[i:i + anchor_len])
                    if ngram not in s1_ngram_counts:
                        s1_ngram_counts[ngram] = []
                    s1_ngram_counts[ngram].append(i)

                s2_ngram_counts = {}
                for i in range(len(seq2) - anchor_len + 1):
                    ngram = tuple(seq2[i:i + anchor_len])
                    if ngram not in s2_ngram_counts:
                        s2_ngram_counts[ngram] = []
                    s2_ngram_counts[ngram].append(i)

                # Find n-grams that appear exactly once in both sequences
                common_ngrams = s1_ngram_counts.keys() & s2_ngram_counts.keys()
                for ngram in common_ngrams:
                    if len(s1_ngram_counts[ngram]) == 1 and len(s2_ngram_counts[ngram]) == 1:
                        i = s1_ngram_counts[ngram][0]
                        j = s2_ngram_counts[ngram][0]
                        # Verify the anchor spans are in range and actually identical
                        if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and
                            seq1[i:i + anchor_len] == seq2[j:j + anchor_len]):
                            anchors_for_this_len.append((i, j, anchor_len))

            all_potential_anchors.extend(anchors_for_this_len)

        # Greedy anchor selection: prefer longer anchors, then earlier positions,
        # skipping any candidate that overlaps an already-selected anchor.
        selected_anchors = []
        used_positions_seq1 = set()
        used_positions_seq2 = set()

        # Sort by anchor length (descending) then by position
        all_potential_anchors.sort(key=lambda x: (-x[2], x[0], x[1]))

        for i, j, anchor_len in all_potential_anchors:
            # Check if this anchor conflicts with already selected ones
            seq1_range = set(range(i, i + anchor_len))
            seq2_range = set(range(j, j + anchor_len))

            if not (seq1_range & used_positions_seq1) and not (seq2_range & used_positions_seq2):
                # This anchor doesn't conflict - we can use it
                selected_anchors.append((i, j, anchor_len))
                used_positions_seq1.update(seq1_range)
                used_positions_seq2.update(seq2_range)

        # Re-sort selected anchors by position for processing
        selected_anchors.sort()

        # Keep only anchors that preserve monotonic ordering in both sequences
        validated_anchors = []
        last_j = -1
        for i, j, anchor_len in selected_anchors:
            if j > last_j:
                # Double-check the anchor is valid
                if (i + anchor_len <= len(seq1) and j + anchor_len <= len(seq2) and
                    seq1[i:i + anchor_len] == seq2[j:j + anchor_len]):
                    validated_anchors.append((i, j, anchor_len))
                    last_j = j + anchor_len - 1

        anchors = 
validated_anchors + + if not anchors: + # If no anchors are found, fall back to the standard alignment. + return self._perform_dp_alignment(seq1, seq2, **kwargs) + + # 2. Align segments between anchors. + full_alignment = [] + last_i, last_j = 0, 0 + + for anchor_idx, (i, j, anchor_len) in enumerate(anchors): + + # Align segment before the current anchor. + seg1, seg2 = seq1[last_i:i], seq2[last_j:j] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + # Adjust indices to be relative to the full sequence and split exact matches. + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + # Add the anchor itself (consecutive tokens), also split if needed. + anchor_seq1 = seq1[i:i + anchor_len] + anchor_seq2 = seq2[j:j + anchor_len] + + # Split anchor into individual matches since they should be identical + for k in range(anchor_len): + full_alignment.append(( + [anchor_seq1[k]], [anchor_seq2[k]], + i + k, i + k + 1, + j + k, j + k + 1 + )) + + last_i, last_j = i + anchor_len, j + anchor_len + + # 3. Align the final segment after the last anchor. 
+ seg1, seg2 = seq1[last_i:], seq2[last_j:] + + if seg1 or seg2: + aligned_segment, _ = self._perform_dp_alignment(seg1, seg2, **kwargs) + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end in aligned_segment: + new_s1_start = s1_start + last_i if s1_start != -1 else -1 + new_s1_end = s1_end + last_i if s1_end != -1 else -1 + new_s2_start = s2_start + last_j if s2_start != -1 else -1 + new_s2_end = s2_end + last_j if s2_end != -1 else -1 + + # Split if both sides have the same tokens + if (len(s1_toks) > 1 and len(s2_toks) > 1 and + len(s1_toks) == len(s2_toks) and s1_toks == s2_toks): + # Split into individual 1-to-1 matches + for k in range(len(s1_toks)): + full_alignment.append(( + [s1_toks[k]], [s2_toks[k]], + new_s1_start + k, new_s1_start + k + 1, + new_s2_start + k, new_s2_start + k + 1 + )) + else: + full_alignment.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + return full_alignment, 0 # Return 0 for score as it's not well-defined here + + def _perform_dp_alignment(self, seq1, seq2, **kwargs): + """ + Helper function to run the core DP-based alignment. 
+ """ + chunk_size = kwargs.get('chunk_size', 0) + kwargs.pop('chunk_size', None) + kwargs.pop('anchor_lengths', None) + + if chunk_size > 0: + return self.align_tokens_combinations_chunked(seq1, seq2, chunk_size=chunk_size, **kwargs) + else: + return self.align_tokens_with_combinations_numpy_jit(seq1, seq2, **kwargs) + + @staticmethod + def _align_chunked_fast( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + chunk_size=256, + ): + """Chunked processing using the fast DP as the base case.""" + n1, n2 = len(seq1), len(seq2) + + if n1 <= chunk_size and n2 <= chunk_size: + return TokenAligner.align_tokens_with_combinations_numpy_fast( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + + mid1, mid2 = n1 // 2, n2 // 2 + + left_aligned, left_score = TokenAligner._align_chunked_fast( + seq1[:mid1], seq2[:mid2], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, chunk_size, + ) + right_aligned, right_score = TokenAligner._align_chunked_fast( + seq1[mid1:], seq2[mid2:], exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, chunk_size, + ) + + adjusted_right = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned: + adjusted_right.append(( + s1_tokens, s2_tokens, + s1_start + mid1 if s1_start >= 0 else -1, + s1_end + mid1 if s1_end >= 0 else -1, + s2_start + mid2 if s2_start >= 0 else -1, + s2_end + mid2 if s2_end >= 0 else -1, + )) + + return left_aligned + adjusted_right, left_score + right_score + + @staticmethod + def _canonical_token(token: str) -> str: + """Return a canonical representation of a tokenizer token.""" + if not token: + return token + + # 1. 
Normalize space prefixes first
        if token.startswith(' '):
            token = 'Ġ' + token[1:]
        elif token.startswith('_'):
            token = 'Ġ' + token[1:]
        elif token.startswith('▁'):  # SentencePiece-style space prefix
            token = 'Ġ' + token[1:]

        # 1.5. Normalize newline and whitespace representations
        if token == 'Ċ':  # GPT-style newline (used by Llama)
            token = '\n'
        elif token == '\\n':  # Escaped newline representation
            token = '\n'
        elif token == 'ĉ':  # GPT-style tab (byte 0x09)
            token = '\t'
        elif token == 'Ġ\n':  # Space + newline combination
            token = '\n'
        elif 'Ċ' in token:  # Handle Ċ embedded in other tokens
            token = token.replace('Ċ', '\n')
        elif '\\n' in token:  # Handle escaped newlines in compound tokens
            token = token.replace('\\n', '\n')

        # 1.6. Handle space-separated punctuation normalization
        if token == 'Ġ,':  # Space + comma
            token = ','
        elif token == 'Ġ.':  # Space + period
            token = '.'
        elif token == 'Ġ;':  # Space + semicolon
            token = ';'
        elif token == 'Ġ:':  # Space + colon
            token = ':'

        # 2. Handle SentencePiece byte fallback tokens like <0x20>
        if token.startswith('<0x') and token.endswith('>') and len(token) == 6:
            try:
                byte_val = int(token[3:5], 16)
                if 0 <= byte_val <= 255:
                    return chr(byte_val)
            except ValueError:
                pass

        # 3. 
Normalize common Unicode encoding issues
        unicode_fixes = {
            # Spanish
            'ñ': 'ñ', 'á': 'á', 'é': 'é', 'í': 'í', 'ó': 'ó', 'ú': 'ú',
            'Ã': 'À', 'â': 'â', 'ç': 'ç',
            # French
            'ç': 'ç', 'è': 'è', 'é': 'é', 'ë': 'ë', 'î': 'î', 'ô': 'ô',
            'ù': 'ù', 'û': 'û', 'ÿ': 'ÿ',
            # Chinese (common encoding artifacts)
            'ä¸Ń': '中', 'æĸĩ': '文', 'æĹ¥æľ¬': '日本', 'èªŀ': '語',
            # Russian
            'ÐłÑĥÑģ': 'Рус', 'Ñģкий': 'ский',
            # Arabic
            'اÙĦعربÙĬØ©': 'العربية',
            # Hindi
            'ह': 'ह', 'िà¤Ĥ': 'हिं', 'दà¥Ģ': 'दी',
            # Mathematical symbols (common artifacts)
            'âĪij': '∑', 'âĪı': '∏', 'âĪĤ': '∂', 'âĪĩ': '∇',
            'âĪŀ': '∞', 'âĪļ': '√', 'âĪ«': '∫', 'âīĪ': '≈',
            'âīł': '≠', 'âī¤': '≤', 'âī¥': '≥',
        }

        # Apply Unicode fixes
        for broken, fixed in unicode_fixes.items():
            if broken in token:
                token = token.replace(broken, fixed)

        # 4. Normalize special tokens
        special_token_map = {
            '<|begin_of_text|>': '',  # Llama-style BOS token
            '<s>': '',  # Standard BOS token
            '<pad>': '',  # Padding tokens → empty (will be handled by alignment)
            '</s>': ' ',  # End tokens
            '<eos>': ' ',  # End tokens
        }

        if token in special_token_map:
            return special_token_map[token]

        return token

    @staticmethod
    def _canonicalize_sequence(seq: List[str]) -> List[str]:
        """Canonicalize every token in a sequence (list of str)."""
        # First, handle multi-token encoding artifacts (before individual canonicalization)
        merged_artifacts = TokenAligner._merge_encoding_artifacts(seq)

        # Then, canonicalize individual tokens
        canon_tokens = [TokenAligner._canonical_token(tok) for tok in merged_artifacts]

        # Finally, merge consecutive byte tokens into proper Unicode characters
        return TokenAligner._merge_consecutive_bytes(canon_tokens)

    @staticmethod
    def _merge_encoding_artifacts(tokens: List[str]) -> List[str]:
        """Merge consecutive tokens that represent multi-token encoding artifacts."""
        if not tokens:
            return tokens

        # Common multi-token encoding artifacts that should be merged
multi_token_fixes = [ + # Mathematical symbols split across tokens + (['ĠâĪ', 'ij'], ['Ġ∑']), # Sum symbol + (['âĪ', 'ij'], ['∑']), # Sum symbol (no space) + (['ĠâĪ', 'ı'], ['Ġ∏']), # Product symbol + (['âĪ', 'ı'], ['∏']), # Product symbol (no space) + (['ĠâĪ', 'Ĥ'], ['Ġ∂']), # Partial derivative + (['âĪ', 'Ĥ'], ['∂']), # Partial derivative (no space) + (['ĠâĪ', 'ĩ'], ['Ġ∇']), # Nabla/gradient + (['âĪ', 'ĩ'], ['∇']), # Nabla/gradient (no space) + (['ĠâĪ', 'ŀ'], ['Ġ∞']), # Infinity + (['âĪ', 'ŀ'], ['∞']), # Infinity (no space) + (['ĠâĪ', 'ļ'], ['Ġ√']), # Square root + (['âĪ', 'ļ'], ['√']), # Square root (no space) + (['ĠâĪ', '«'], ['Ġ∫']), # Integral + (['âĪ', '«'], ['∫']), # Integral (no space) + (['Ġâī', 'ł'], ['Ġ≠']), # Not equal + (['âī', 'ł'], ['≠']), # Not equal (no space) + # Other common multi-token artifacts + (['Ġä¸', 'Ń'], ['Ġ中']), # Chinese character + (['ä¸', 'Ń'], ['中']), # Chinese character (no space) + (['æĸ', 'ĩ'], ['文']), # Chinese character + (['Ġæĸ', 'ĩ'], ['Ġ文']), # Chinese character (with space) + ] + + result = [] + i = 0 + + while i < len(tokens): + # Check if current position matches any multi-token pattern + matched = False + + for pattern, replacement in multi_token_fixes: + pattern_len = len(pattern) + if i + pattern_len <= len(tokens): + # Check if the tokens match the pattern + if tokens[i:i+pattern_len] == pattern: + # Replace with the fixed version + result.extend(replacement) + i += pattern_len + matched = True + break + + if not matched: + # No pattern matched, keep the original token + result.append(tokens[i]) + i += 1 + + return result + + @staticmethod + def _merge_consecutive_bytes(tokens: List[str]) -> List[str]: + """Merge consecutive tokens that represent UTF-8 byte sequences.""" + if not tokens: + return tokens + + result = [] + byte_buffer = [] + + for token in tokens: + # Check if this token represents byte(s) + clean_token = token.lstrip('Ġ') + + # Check if all characters in the token are visual bytes + 
all_chars_are_bytes = True + if len(clean_token) == 0: + all_chars_are_bytes = False + else: + for char in clean_token: + if TokenAligner._get_byte_value(char) is None: + all_chars_are_bytes = False + break + + if all_chars_are_bytes: + byte_buffer.append(token) + else: + # Not a byte token, flush buffer first + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + byte_buffer = [] + result.append(token) + + # Flush any remaining bytes + if byte_buffer: + merged = TokenAligner._try_merge_byte_buffer(byte_buffer) + result.extend(merged) + + return result + + @staticmethod + def _try_merge_byte_buffer(byte_tokens: List[str]) -> List[str]: + """Try to merge a buffer of potential byte tokens into a Unicode character.""" + if not byte_tokens: + return [] + + # If only one token, just return it unless it's a multi-character byte token + if len(byte_tokens) == 1: + token = byte_tokens[0] + clean_token = token.lstrip('Ġ') + if len(clean_token) <= 1: + return byte_tokens + # Continue processing multi-character token + + # Extract space prefix from first token + first_token = byte_tokens[0] + space_prefix = 'Ġ' if first_token.startswith('Ġ') else '' + + # Extract raw bytes from all characters in all tokens + raw_bytes = [] + for token in byte_tokens: + clean_token = token.lstrip('Ġ') + for char in clean_token: + byte_value = TokenAligner._get_byte_value(char) + if byte_value is not None: + raw_bytes.append(byte_value) + else: + # If any character is not a byte, return original tokens + return byte_tokens + + # Only try to merge if we have 2-4 bytes (typical for emoji/multi-byte chars) + if len(raw_bytes) < 2 or len(raw_bytes) > 4: + return byte_tokens + + # Try to decode as UTF-8 + try: + decoded_text = bytes(raw_bytes).decode('utf-8') + # Only merge if the result is a single Unicode character (like an emoji) + if len(decoded_text) == 1 and ord(decoded_text) > 127: + return [space_prefix + decoded_text] + else: + # If it's not 
a single special character, keep original tokens + return byte_tokens + except UnicodeDecodeError: + # If decoding fails, return original tokens + return byte_tokens + + # Common visual byte representations used by some tokenizers (especially for emojis) + VISUAL_BYTE_MAP = { + # Common emoji byte range (240-255) + 'ð': 240, 'Ɩ': 241, 'Ɨ': 242, 'Ƙ': 243, 'ƙ': 244, 'ƚ': 245, 'ƛ': 246, 'Ɯ': 247, + 'Ɲ': 248, 'ƞ': 249, 'Ɵ': 250, 'Ơ': 251, 'ơ': 252, 'Ƣ': 253, 'ƣ': 254, 'Ƥ': 255, + # Other common byte representations (0-255 only) + 'Ł': 156, 'ł': 157, 'Ń': 158, 'ń': 159, 'ĺ': 149, 'Ļ': 150, 'ļ': 151, 'Ľ': 152, + 'ľ': 153, 'Ŀ': 154, 'ŀ': 155, 'Ĭ': 135, 'ĭ': 136, 'Į': 137, 'į': 138, 'İ': 139, + 'ı': 140, 'IJ': 141, 'ij': 142, 'Ĵ': 143, 'ĵ': 144, 'Ķ': 145, 'ķ': 146, 'ĸ': 147, + 'Ĺ': 148, 'ĥ': 128, 'Ħ': 129, 'ħ': 130, 'Ĩ': 131, 'ĩ': 132, 'Ī': 133, 'ī': 134, + 'Ģ': 162, 'ģ': 163, 'Ĝ': 28, 'ĝ': 29, 'Ğ': 30, 'ğ': 31, + } + + @staticmethod + def _get_byte_value(token_char: str) -> int: + """Get the byte value for a character, handling both direct bytes and visual representations.""" + if len(token_char) != 1: + return None + + char_ord = ord(token_char) + + # Direct byte (0-255) + if char_ord < 256: + return char_ord + + # Visual byte representation + if token_char in TokenAligner.VISUAL_BYTE_MAP: + return TokenAligner.VISUAL_BYTE_MAP[token_char] + + return None + + @staticmethod + def _strings_equal_flexible(s1, s2, ignore_leading_char_diff): + if not ignore_leading_char_diff: + return s1 == s2 + + # Use our comprehensive canonicalization for robust comparison + s1_canonical = TokenAligner._canonical_token(s1) + s2_canonical = TokenAligner._canonical_token(s2) + + return s1_canonical == s2_canonical + + @staticmethod + def align_tokens_with_combinations_numpy_fast( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + band_width=None, + ): + """DP alignment using integer token 
IDs, int32 trace, and optional band constraint. + + Produces the same result as ``align_tokens_with_combinations_numpy`` but + replaces per-cell Python string comparisons with integer comparisons and + uses a compact int32 trace array instead of a Python object array. + + When *band_width* is set (recommended for cross-tokenizer alignment where + both sequences encode the same text), the DP is restricted to a diagonal + band of width ``2 * band_width + 1``, reducing complexity from + O(n1 * n2) to O(n1 * band_width). + + Note: ``ignore_leading_char_diff`` must be False. The caller + (``_align_single``) canonicalizes sequences before calling the DP, + so flexible comparison is never needed here. + """ + if ignore_leading_char_diff: + return TokenAligner.align_tokens_with_combinations_numpy( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + n1, n2 = len(seq1), len(seq2) + if n1 == 0 and n2 == 0: + return [], 0.0 + if n1 == 0: + return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty + if n2 == 0: + return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty + + token_to_id: dict[str, int] = {} + _next = [0] + + def _id(s: str) -> int: + tid = token_to_id.get(s) + if tid is None: + tid = _next[0] + token_to_id[s] = tid + _next[0] += 1 + return tid + + ids1 = [_id(t) for t in seq1] + ids2 = [_id(t) for t in seq2] + + joined_ids1: dict[tuple[int, int], int] = {} + for i in range(n1 + 1): + for k in range(2, min(i, max_combination_len) + 1): + joined_ids1[(i - k, i)] = _id(''.join(seq1[i - k:i])) + + joined_ids2: dict[tuple[int, int], int] = {} + for j in range(n2 + 1): + for k in range(2, min(j, max_combination_len) + 1): + joined_ids2[(j - k, j)] = _id(''.join(seq2[j - k:j])) + + use_band = band_width is not None + NEG_INF = np.float32(-1e9) + dp = np.full((n1 + 1, n2 + 1), NEG_INF, dtype=np.float32) + # Trace codes: 0=start, 1=diag, 2=up, 
3=left, + # 10+k = comb_s1_over_s2_k, 20+k = comb_s2_over_s1_k + trace = np.zeros((n1 + 1, n2 + 1), dtype=np.int32) + + dp[0, 0] = 0.0 + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 2 + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 3 + + scale = n2 / max(n1, 1) + exact = np.float32(exact_match_score) + neg_exact = np.float32(-exact_match_score) + gap = np.float32(gap_penalty) + comb_mul = np.float32(combination_score_multiplier) + + for i in range(1, n1 + 1): + if use_band: + ej = int(i * scale) + j_lo = max(1, ej - band_width) + j_hi = min(n2, ej + band_width) + else: + j_lo = 1 + j_hi = n2 + + id_i = ids1[i - 1] + + for j in range(j_lo, j_hi + 1): + id_j = ids2[j - 1] + + best = dp[i - 1, j - 1] + (exact if id_i == id_j else neg_exact) + best_m = 1 + + s = dp[i - 1, j] + gap + if s > best: + best = s + best_m = 2 + + s = dp[i, j - 1] + gap + if s > best: + best = s + best_m = 3 + + for k in range(2, min(j + 1, max_combination_len + 1)): + key = (j - k, j) + if key in joined_ids2 and id_i == joined_ids2[key]: + s = dp[i - 1, j - k] + comb_mul * k + if s > best: + best = s + best_m = 10 + k + + for k in range(2, min(i + 1, max_combination_len + 1)): + key = (i - k, i) + if key in joined_ids1 and id_j == joined_ids1[key]: + s = dp[i - k, j - 1] + comb_mul * k + if s > best: + best = s + best_m = 20 + k + + dp[i, j] = best + trace[i, j] = best_m + + aligned: list = [] + i, j = n1, n2 + while i > 0 or j > 0: + m = int(trace[i, j]) + if m == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1; j -= 1 + elif m == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif m == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= m < 20: + k = m - 10 + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1; j -= k + elif 20 <= m < 30: + k = m - 20 + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j 
- 1, j)) + i -= k; j -= 1 + else: + break + + aligned.reverse() + return aligned, float(dp[n1, n2]) + + @staticmethod + def align_tokens_with_combinations_numpy(seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False): + n1, n2 = len(seq1), len(seq2) + dp = np.zeros((n1 + 1, n2 + 1), dtype=np.float32) + trace = np.full((n1 + 1, n2 + 1), '', dtype=object) + + # Initialize DP edges with gap penalties + for i in range(1, n1 + 1): + dp[i, 0] = dp[i - 1, 0] + gap_penalty + trace[i, 0] = 'up' + for j in range(1, n2 + 1): + dp[0, j] = dp[0, j - 1] + gap_penalty + trace[0, j] = 'left' + + # Precompute joined substrings for all valid k-length spans + joined_seq1 = {(i - k, i): ''.join(seq1[i - k:i]) + for i in range(n1 + 1) + for k in range(1, min(i, max_combination_len) + 1)} + joined_seq2 = {(j - k, j): ''.join(seq2[j - k:j]) + for j in range(n2 + 1) + for k in range(1, min(j, max_combination_len) + 1)} + + # Fill DP table + for i in range(1, n1 + 1): + for j in range(1, n2 + 1): + s1_val, s2_val = seq1[i - 1], seq2[j - 1] + match_score = exact_match_score if TokenAligner._strings_equal_flexible(s1_val, s2_val, ignore_leading_char_diff) else -exact_match_score + score_diag = dp[i - 1, j - 1] + match_score + score_up = dp[i - 1, j] + gap_penalty + score_left = dp[i, j - 1] + gap_penalty + + max_score = score_diag + best_move = 'diag' + if score_up > max_score: + max_score = score_up + best_move = 'up' + if score_left > max_score: + max_score = score_left + best_move = 'left' + + # Check for seq1[i-1] == join(seq2[j-k:j]) + for k in range(2, min(j + 1, max_combination_len + 1)): + if (j - k, j) in joined_seq2 and TokenAligner._strings_equal_flexible(s1_val, joined_seq2[(j - k, j)], ignore_leading_char_diff): + comb_score = dp[i - 1, j - k] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s1_over_s2_{k}' + + # Check for 
seq2[j-1] vs seq1[i-k:i] + for k in range(2, min(i + 1, max_combination_len + 1)): + if (i - k, i) in joined_seq1 and TokenAligner._strings_equal_flexible(s2_val, joined_seq1[(i - k, i)], ignore_leading_char_diff): + comb_score = dp[i - k, j - 1] + combination_score_multiplier * k + if comb_score > max_score: + max_score = comb_score + best_move = f'comb_s2_over_s1_{k}' + + dp[i, j] = max_score + trace[i, j] = best_move + + # Backtrack to extract alignment + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + move = trace[i, j] + if move == 'diag': + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == 'up': + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == 'left': + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif move.startswith('comb_s1_over_s2_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif move.startswith('comb_s2_over_s1_'): + k = int(move.rsplit('_', 1)[-1]) + aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned, dp[n1, n2] + + @staticmethod + def align_tokens_with_combinations_numpy_jit( + seq1, seq2, + exact_match_score=3, + combination_score_multiplier=1.5, + gap_penalty=-1.5, + max_combination_len=4, + ignore_leading_char_diff=False, + ): + """Numba-accelerated version of align_tokens_with_combinations_numpy. + + Pre-converts string tokens to integer IDs, runs the DP in a Numba + @njit kernel, then backtracks using the original string tokens. + Falls back to the pure-Python original when Numba is unavailable or + when ignore_leading_char_diff is True (requires Python string logic). 
+ """ + if not _NUMBA_AVAILABLE or ignore_leading_char_diff: + return TokenAligner.align_tokens_with_combinations_numpy( + seq1, seq2, exact_match_score, combination_score_multiplier, + gap_penalty, max_combination_len, ignore_leading_char_diff, + ) + + n1, n2 = len(seq1), len(seq2) + if n1 == 0 and n2 == 0: + return [], 0.0 + if n1 == 0: + return [([], [seq2[j]], -1, -1, j, j + 1) for j in range(n2)], n2 * gap_penalty + if n2 == 0: + return [([seq1[i]], [], i, i + 1, -1, -1) for i in range(n1)], n1 * gap_penalty + + token_to_id: dict[str, int] = {} + _next_id = [0] + + def _get_id(s: str) -> int: + tid = token_to_id.get(s) + if tid is None: + tid = _next_id[0] + token_to_id[s] = tid + _next_id[0] += 1 + return tid + + ids1 = np.array([_get_id(t) for t in seq1], dtype=np.int64) + ids2 = np.array([_get_id(t) for t in seq2], dtype=np.int64) + + INVALID = np.int64(-1) + joined1 = np.full((n1 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for i in range(n1 + 1): + for k in range(2, min(i, max_combination_len) + 1): + joined1[i, k] = _get_id(''.join(seq1[i - k:i])) + + joined2 = np.full((n2 + 1, max_combination_len + 1), INVALID, dtype=np.int64) + for j in range(n2 + 1): + for k in range(2, min(j, max_combination_len) + 1): + joined2[j, k] = _get_id(''.join(seq2[j - k:j])) + + dp, trace = _dp_core_numba( + ids1, ids2, joined1, joined2, n1, n2, + np.float32(exact_match_score), + np.float32(gap_penalty), + np.float32(combination_score_multiplier), + max_combination_len, + ) + + aligned = [] + i, j = n1, n2 + while i > 0 or j > 0: + m = trace[i, j] + if m == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif m == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif m == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= m < 20: + k = m - 10 + aligned.append(([seq1[i - 1]], seq2[j - k:j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif 20 <= m < 30: + k = m - 20 + 
aligned.append((seq1[i - k:i], [seq2[j - 1]], i - k, i, j - 1, j))
+                i -= k
+                j -= 1
+            else:
+                break
+
+        aligned.reverse()
+        return aligned, float(dp[n1, n2])
+
+    @staticmethod
+    def align_tokens_combinations_chunked(
+        seq1: List[str],
+        seq2: List[str],
+        exact_match_score: float = 3.0,
+        combination_score_multiplier: float = 1.5,
+        gap_penalty: float = -1.5,
+        max_combination_len: int = 4,
+        ignore_leading_char_diff: bool = False,
+        chunk_size: int = 256,
+    ):
+        """
+        Chunked divide-and-conquer alignment for very large sequences.
+
+        Sequences no longer than ``chunk_size`` are aligned exactly; longer
+        sequences are split at their midpoints and each half is aligned
+        recursively. The split is a heuristic, so the result can be slightly
+        worse than a full DP alignment when a true alignment unit straddles
+        a split point.
+        """
+        n1, n2 = len(seq1), len(seq2)
+
+        # If sequences are small enough, use regular algorithm
+        if n1 <= chunk_size and n2 <= chunk_size:
+            return TokenAligner.align_tokens_with_combinations_numpy_jit(
+                seq1, seq2, exact_match_score, combination_score_multiplier,
+                gap_penalty, max_combination_len, ignore_leading_char_diff
+            )
+
+        # For very large sequences, use divide-and-conquer approach
+        if n1 > chunk_size or n2 > chunk_size:
+            # Split both sequences at their raw midpoints (a heuristic: no
+            # attempt is made to find the optimal crossing point)
+            mid1, mid2 = n1 // 2, n2 // 2
+
+            # Recursively align left and right parts
+            left_aligned, left_score = TokenAligner.align_tokens_combinations_chunked(
+                seq1[:mid1], seq2[:mid2], exact_match_score, combination_score_multiplier,
+                gap_penalty, max_combination_len, ignore_leading_char_diff,
+                chunk_size=chunk_size
+            )
+
+            right_aligned, right_score = TokenAligner.align_tokens_combinations_chunked(
+                seq1[mid1:], seq2[mid2:], exact_match_score, combination_score_multiplier,
+                gap_penalty, max_combination_len, ignore_leading_char_diff,
+                chunk_size=chunk_size
+            )
+
+            # Adjust indices for right part
+            adjusted_right = []
+            for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned:
+                new_s1_start = s1_start + mid1 if s1_start >= 0 else -1
+                new_s1_end = s1_end + mid1 if s1_end >= 0 else -1
+                new_s2_start = s2_start + mid2 if s2_start >= 0 else -1
+                new_s2_end = s2_end + mid2 if s2_end >= 0 else -1
+                adjusted_right.append((s1_tokens,
s2_tokens, new_s1_start, new_s1_end, new_s2_start, new_s2_end))
+
+            # Combine results
+            combined_aligned = left_aligned + adjusted_right
+            combined_score = left_score + right_score
+
+            return combined_aligned, combined_score
+
+        # Fallback to regular algorithm
+        return TokenAligner.align_tokens_with_combinations_numpy_jit(
+            seq1, seq2, exact_match_score, combination_score_multiplier,
+            gap_penalty, max_combination_len, ignore_leading_char_diff
+        )
+
+    @staticmethod
+    def _combine_consecutive_misaligned_tokens(
+        aligned_pairs: List,
+        pair_strings: List,
+        end_mismatch_threshold: float = 0.2
+    ) -> List:
+        """
+        Combine consecutive misaligned tokens into single chunks to improve alignment.
+
+        This addresses cases where multiple tokens are individually misaligned but
+        collectively represent the same content. Avoids combining tokens near the
+        end of sequences that might be misaligned due to length differences.
+
+        Args:
+            aligned_pairs: List of alignment pairs
+            pair_strings: Precomputed string representations and match status
+            end_mismatch_threshold: Fraction of sequence from end to avoid chunking
+
+        Returns:
+            Modified aligned_pairs with consecutive misaligned tokens combined
+        """
+        if not aligned_pairs or len(aligned_pairs) < 2:
+            return aligned_pairs
+
+        # Calculate the boundary for avoiding end mismatches
+        sequence_length = len(aligned_pairs)
+        end_boundary = int(sequence_length * (1 - end_mismatch_threshold))
+
+        processed_pairs = []
+        i = 0
+
+        while i < len(aligned_pairs):
+            # Check if current pair is misaligned and not near the end
+            if (i < end_boundary and
+                    not pair_strings[i][2] and  # Current pair is misaligned
+                    i + 1 < len(aligned_pairs)):  # Not the last pair
+
+                # Find consecutive misaligned pairs
+                consecutive_misaligned = [i]
+                j = i + 1
+
+                # Look ahead for more consecutive misaligned pairs (up to end boundary)
+                while (j < end_boundary and
+                        j < len(aligned_pairs) and
+                        not pair_strings[j][2]):  # Next pair is also misaligned
+                    consecutive_misaligned.append(j)
+                    j += 1
+
+                # Only combine if we have multiple consecutive misaligned pairs
+                if len(consecutive_misaligned) >= 2:
+                    # Combine all consecutive misaligned pairs into one chunk
+                    combined_s1_tokens = []
+                    combined_s2_tokens = []
+                    s1_indices = []
+                    s2_indices = []
+
+                    for idx
in consecutive_misaligned: + s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest = aligned_pairs[idx] + combined_s1_tokens.extend(s1_tokens) + combined_s2_tokens.extend(s2_tokens) + + if s1_tokens and s1_start != -1: + s1_indices.extend([s1_start, s1_end]) + if s2_tokens and s2_start != -1: + s2_indices.extend([s2_start, s2_end]) + + # Calculate combined indices + combined_s1_start = min(s1_indices[::2]) if s1_indices else -1 + combined_s1_end = max(s1_indices[1::2]) if s1_indices else -1 + combined_s2_start = min(s2_indices[::2]) if s2_indices else -1 + combined_s2_end = max(s2_indices[1::2]) if s2_indices else -1 + + # Create combined pair + combined_pair = ( + combined_s1_tokens, + combined_s2_tokens, + combined_s1_start, + combined_s1_end, + combined_s2_start, + combined_s2_end + ) + + processed_pairs.append(combined_pair) + i = j # Skip to after the combined region + else: + # Only one misaligned pair, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + else: + # Current pair is aligned or near the end, keep as is + processed_pairs.append(aligned_pairs[i]) + i += 1 + + return processed_pairs + + + @staticmethod + def post_process_alignment_optimized( + aligned_pairs: List, + ignore_leading_char_diff: bool = False, + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + combine_misaligned_chunks: bool = True, + end_mismatch_threshold: float = 0.2 + ) -> List: + """ + Optimized version of post_process_alignment with better performance. + + Key optimizations: + 1. Precompute string concatenations to avoid repeated joins + 2. Early termination when no bad regions are found + 3. Cache alignment results for repeated chunk patterns + 4. Vectorized index calculations + 5. Reduced nested loop complexity + 6. 
Combine multiple consecutive misaligned tokens into single chunks + + Args: + combine_misaligned_chunks: If True, combine consecutive misaligned tokens into chunks + end_mismatch_threshold: Fraction of sequence length from end to avoid chunking (0.2 = last 20%) + """ + if not aligned_pairs: + return [] + + # Precompute joined strings for all pairs to avoid repeated concatenation + # Use canonicalization for robust comparison + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + # Canonicalize individual tokens before joining for better matching + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + # Step 1: Handle consecutive misaligned chunks if enabled + if combine_misaligned_chunks: + aligned_pairs = TokenAligner._combine_consecutive_misaligned_tokens( + aligned_pairs, pair_strings, end_mismatch_threshold + ) + + # Recompute pair_strings after combining misaligned chunks + pair_strings = [] + for i, (s1_tokens, s2_tokens, *rest) in enumerate(aligned_pairs): + s1_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s1_tokens] if s1_tokens else [] + s2_canonical_tokens = [TokenAligner._canonical_token(tok) for tok in s2_tokens] if s2_tokens else [] + s1_str = "".join(s1_canonical_tokens) + s2_str = "".join(s2_canonical_tokens) + is_match = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + pair_strings.append((s1_str, s2_str, is_match)) + + processed_pairs = [] + alignment_cache = {} # Cache for repeated alignment patterns + i = 0 + + while i < len(aligned_pairs): + s1_tokens, s2_tokens, *_ = aligned_pairs[i] + + # 
Handle coarse alignments that can be split (optimized) + if len(s1_tokens) > 1 and len(s1_tokens) == len(s2_tokens) and s1_tokens == s2_tokens: + s1_start, s1_end, s2_start, s2_end = aligned_pairs[i][2:6] + # Vectorized creation of split pairs + for k in range(len(s1_tokens)): + processed_pairs.append( + ([s1_tokens[k]], [s2_tokens[k]], + s1_start + k, s1_start + k + 1, + s2_start + k, s2_start + k + 1) + ) + i += 1 + continue + + # Find bad regions more efficiently using precomputed strings + start_bad_region = -1 + for j in range(i, len(aligned_pairs)): + if not pair_strings[j][2]: # is_match is False + start_bad_region = j + break + + if start_bad_region == -1: + # No more bad regions - add remaining pairs and exit + processed_pairs.extend(aligned_pairs[i:]) + break + + # Add good pairs before bad region + processed_pairs.extend(aligned_pairs[i:start_bad_region]) + + # Optimized chunk processing with early termination + found_fix = False + max_chunk_size = min(10, len(aligned_pairs) - start_bad_region) # Limit search space + + for chunk_size in range(2, max_chunk_size + 1): + chunk = aligned_pairs[start_bad_region : start_bad_region + chunk_size] + + # Efficient token extraction using list comprehension + chunk_s1_tokens = [] + chunk_s2_tokens = [] + s1_indices = [] + s2_indices = [] + + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *rest in chunk: + chunk_s1_tokens.extend(s1_toks) + chunk_s2_tokens.extend(s2_toks) + if s1_toks: + s1_indices.extend([s1_start, s1_end]) + if s2_toks: + s2_indices.extend([s2_start, s2_end]) + + # Quick string comparison using canonicalization + chunk_s1_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s1_tokens] + chunk_s2_canonical = [TokenAligner._canonical_token(tok) for tok in chunk_s2_tokens] + chunk_s1_str = "".join(chunk_s1_canonical) + chunk_s2_str = "".join(chunk_s2_canonical) + + if not TokenAligner._strings_equal_flexible(chunk_s1_str, chunk_s2_str, ignore_leading_char_diff): + continue + + # 
Create cache key for alignment + cache_key = (tuple(chunk_s1_tokens), tuple(chunk_s2_tokens)) + + if cache_key in alignment_cache: + sub_aligned_pairs, realign_is_perfect = alignment_cache[cache_key] + else: + # Perform alignment + sub_aligned_pairs, _ = TokenAligner.align_tokens_with_combinations_numpy( + chunk_s1_tokens, + chunk_s2_tokens, + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + ignore_leading_char_diff=ignore_leading_char_diff + ) + + # Check if re-alignment was successful using canonicalization + realign_is_perfect = all( + TokenAligner._strings_equal_flexible( + "".join([TokenAligner._canonical_token(tok) for tok in p[0]]), + "".join([TokenAligner._canonical_token(tok) for tok in p[1]]), + ignore_leading_char_diff + ) + for p in sub_aligned_pairs + ) + + # Cache the result + alignment_cache[cache_key] = (sub_aligned_pairs, realign_is_perfect) + + # Vectorized index calculations + s1_chunk_start = min(s1_indices[::2]) if s1_indices else -1 + s2_chunk_start = min(s2_indices[::2]) if s2_indices else -1 + + if realign_is_perfect: + # Add granular aligned pairs + for s1_toks, s2_toks, s1_start, s1_end, s2_start, s2_end, *_ in sub_aligned_pairs: + new_s1_start = s1_chunk_start + s1_start if s1_start != -1 else -1 + new_s1_end = s1_chunk_start + s1_end if s1_end != -1 else -1 + new_s2_start = s2_chunk_start + s2_start if s2_start != -1 else -1 + new_s2_end = s2_chunk_start + s2_end if s2_end != -1 else -1 + processed_pairs.append((s1_toks, s2_toks, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + else: + # Create merged pair + s1_chunk_end = max(s1_indices[1::2]) if s1_indices else -1 + s2_chunk_end = max(s2_indices[1::2]) if s2_indices else -1 + merged_pair = (chunk_s1_tokens, chunk_s2_tokens, s1_chunk_start, s1_chunk_end, s2_chunk_start, s2_chunk_end) + processed_pairs.append(merged_pair) + + i = start_bad_region + chunk_size + 
found_fix = True + break + + if not found_fix: + processed_pairs.append(aligned_pairs[start_bad_region]) + i = start_bad_region + 1 + + return processed_pairs + + @staticmethod + def get_alignment_mask(aligned_pairs: List, use_canonicalization: bool = True, + ignore_leading_char_diff: bool = False) -> List[bool]: + """ + Get a boolean mask indicating which alignments are correct. + """ + if not aligned_pairs: + return [] + + # Handle batch case - take first batch + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + pairs_to_verify = aligned_pairs[0] + else: + pairs_to_verify = aligned_pairs + + mask = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *rest in pairs_to_verify: + # Concatenate tokens into strings + s1_str = "".join(s1_tokens) if s1_tokens else "" + s2_str = "".join(s2_tokens) if s2_tokens else "" + + # Apply canonicalization if requested + if use_canonicalization: + s1_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s1_tokens]) if s1_tokens else "" + s2_canonical = "".join([TokenAligner._canonical_token(tok) for tok in s2_tokens]) if s2_tokens else "" + is_correct = TokenAligner._strings_equal_flexible(s1_canonical, s2_canonical, ignore_leading_char_diff) + else: + if ignore_leading_char_diff: + is_correct = TokenAligner._strings_equal_flexible(s1_str, s2_str, ignore_leading_char_diff) + else: + is_correct = s1_str == s2_str + + mask.append(is_correct) + + return mask + + def _update_rules(self, aligned_pairs, student_token_ids=None, teacher_token_ids=None, student_sequence=None, teacher_sequence=None): + """Update rule tracking with aligned pairs""" + # Track how many times each rule is triggered + if not hasattr(self, "forward_rule_counts"): + self.forward_rule_counts = {} + if not hasattr(self, "reverse_rule_counts"): + self.reverse_rule_counts = {} + if not hasattr(self, "forward_rules_with_ids"): + self.forward_rules_with_ids = {} # Maps (token_strings) -> 
(token_ids, count) + if not hasattr(self, "reverse_rules_with_ids"): + self.reverse_rules_with_ids = {} # Maps (token_strings) -> (token_ids, count) + if not hasattr(self, "rule_conflicts"): + self.rule_conflicts = {} # Track conflicting rules: source -> set of conflicting targets + if not hasattr(self, "rule_conflict_counts"): + self.rule_conflict_counts = {} # Track counts: (source, target) -> count + if not hasattr(self, "conflict_contexts"): + self.conflict_contexts = {} # Track full context when conflicts occur: conflict_id -> context_data + + for s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest in aligned_pairs: + # Extract mask value if available, default to True for backward compatibility + is_correct = rest[0] if rest else True + + # Only add rules if the alignment is correct (mask is positive) + if not is_correct: + continue + + s1_tuple = tuple(s1_elems) + s2_tuple = tuple(s2_elems) + + if s1_tuple and s2_tuple: + # Extract token IDs if available + s1_ids = None + s2_ids = None + if student_token_ids is not None and s1_start != -1 and s1_end != -1: + s1_ids = tuple(student_token_ids[s1_start:s1_end]) + if teacher_token_ids is not None and s2_start != -1 and s2_end != -1: + s2_ids = tuple(teacher_token_ids[s2_start:s2_end]) + + # Check for conflicts in existing rules + existing_targets = [rule[1] for rule in self.forward_rules if rule[0] == s1_tuple] + if existing_targets and s2_tuple not in existing_targets: + # Initialize conflict tracking for this source if not exists + if s1_tuple not in self.rule_conflicts: + self.rule_conflicts[s1_tuple] = set() + + # Add all targets (existing + new) to the conflict set + all_targets = set(existing_targets + [s2_tuple]) + old_conflict_size = len(self.rule_conflicts[s1_tuple]) + self.rule_conflicts[s1_tuple].update(all_targets) + + # Store the full context of this conflict + if len(self.rule_conflicts[s1_tuple]) > old_conflict_size: + conflict_id = f"{hash((s1_tuple, s2_tuple, 
len(self.conflict_contexts)))}" + context_data = { + 'conflict_source': s1_tuple, + 'new_target': s2_tuple, + 'existing_targets': existing_targets, + 'student_sequence': student_sequence, + 'teacher_sequence': teacher_sequence, + 'student_token_ids': student_token_ids, + 'teacher_token_ids': teacher_token_ids, + 'full_alignment': aligned_pairs, + 'conflict_position': (s1_start, s1_end, s2_start, s2_end), + 'student_ids_at_conflict': s1_ids, + 'teacher_ids_at_conflict': s2_ids, + 'timestamp': __import__('time').time(), + } + self.conflict_contexts[conflict_id] = context_data + + # Store string-based rules (backward compatibility) + self.forward_rules.add((s1_tuple, s2_tuple)) + self.reverse_rules.add((s2_tuple, s1_tuple)) + + # Track conflict counts for this specific source-target pair + conflict_key = (s1_tuple, s2_tuple) + + # After adding to rules, check if this is now part of a conflict + # and update conflict counts accordingly + if s1_tuple in self.rule_conflicts and s2_tuple in self.rule_conflicts[s1_tuple]: + # This rule is part of a conflict, count it + self.rule_conflict_counts[conflict_key] = self.rule_conflict_counts.get(conflict_key, 0) + 1 + + # Also retroactively count any other conflicting rules for this source + # that we may have missed when they weren't conflicts yet + for target in self.rule_conflicts[s1_tuple]: + target_key = (s1_tuple, target) + if target_key != conflict_key and target_key not in self.rule_conflict_counts: + # Count how many times this rule has been used + rule_count = self.forward_rule_counts.get(target_key, 0) + if rule_count > 0: + self.rule_conflict_counts[target_key] = rule_count + + # Store rules with token IDs + forward_key = (s1_tuple, s2_tuple) + reverse_key = (s2_tuple, s1_tuple) + + if forward_key not in self.forward_rules_with_ids: + self.forward_rules_with_ids[forward_key] = { + 'student_ids': s1_ids, + 'teacher_ids': s2_ids, + 'count': 0 + } + if reverse_key not in self.reverse_rules_with_ids: + 
self.reverse_rules_with_ids[reverse_key] = { + 'teacher_ids': s1_ids, # Note: reversed + 'student_ids': s2_ids, # Note: reversed + 'count': 0 + } + + # Count how many times the rule was triggered + self.forward_rule_counts[forward_key] = self.forward_rule_counts.get(forward_key, 0) + 1 + self.reverse_rule_counts[reverse_key] = self.reverse_rule_counts.get(reverse_key, 0) + 1 + + # Update counts in the ID-based rules + self.forward_rules_with_ids[forward_key]['count'] += 1 + self.reverse_rules_with_ids[reverse_key]['count'] += 1 + + def translate(self, sequence, direction='forward'): + """ + Translate a sequence using current rules. + direction: 'forward' (seq1→seq2) or 'reverse' (seq2→seq1) + """ + rules = self.forward_rules if direction == 'forward' else self.reverse_rules + rule_map = {src: tgt for src, tgt in rules} + sorted_rules = sorted(rule_map.items(), key=lambda x: -len(x[0])) + + output = [] + i = 0 + while i < len(sequence): + matched = False + for src, tgt in sorted_rules: + src_len = len(src) + if tuple(sequence[i:i + src_len]) == src: + output.extend(tgt) + i += src_len + matched = True + break + if not matched: + output.append(sequence[i]) + i += 1 + return output + + def translate_via_alignment_spans(self, source_sequence, aligned_pairs, source_is_seq1=True): + """ + Translate a sequence using explicit alignment spans without reconstructing rules. 
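
        The span-walking logic can be sketched standalone (illustrative,
        simplified version of this method; the helper below is not part of
        this class):

        ```python
        def translate_spans(source, aligned_pairs, source_is_seq1=True):
            # Walk the alignment spans in order, emitting the target-side
            # tokens wherever the source-side tokens match the input sequence.
            translated, idx = [], 0
            for s1_tokens, s2_tokens, *_ in aligned_pairs:
                src = s1_tokens if source_is_seq1 else s2_tokens
                tgt = s2_tokens if source_is_seq1 else s1_tokens
                if not src:
                    translated.extend(tgt)      # insertion on the target side
                else:
                    if source[idx: idx + len(src)] == src:
                        translated.extend(tgt)  # matched span -> emit aligned target
                    idx += len(src)             # advance past the span either way
            return translated

        pairs = [
            (["he", "llo"], ["hello"], 0, 2, 0, 1),
            ([" wor", "ld"], [" world"], 2, 4, 1, 2),
        ]
        assert translate_spans(["he", "llo", " wor", "ld"], pairs) == ["hello", " world"]
        # Reverse direction reuses the same spans with the sides swapped:
        assert translate_spans(["hello", " world"], pairs,
                               source_is_seq1=False) == ["he", "llo", " wor", "ld"]
        ```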
+ """ + translated = [] + source_idx = 0 + + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end, *_ in aligned_pairs: + current_source = s1_tokens if source_is_seq1 else s2_tokens + target = s2_tokens if source_is_seq1 else s1_tokens + + if not current_source: + # Insertion in target, not aligned in source → emit target + translated.extend(target) + else: + match_span = source_sequence[source_idx : source_idx + len(current_source)] + if match_span == current_source: + translated.extend(target) + source_idx += len(current_source) + else: + source_idx += len(current_source) # Skip to avoid infinite loop + + return translated + + def get_rules(self, direction='forward'): + return self.forward_rules if direction == 'forward' else self.reverse_rules + + def reset_rules(self): + self.forward_rules.clear() + self.reverse_rules.clear() + + def enable_rule_tracking(self): + """Enable rule tracking for future alignments.""" + self.track_rules = True + + def disable_rule_tracking(self): + """Disable rule tracking for future alignments.""" + self.track_rules = False + + def is_rule_tracking_enabled(self): + """Check if rule tracking is currently enabled.""" + return self.track_rules + + def clear_all_rules(self): + """Clear all collected rules and reset rule tracking data.""" + self.forward_rules.clear() + self.reverse_rules.clear() + + # Clear rule counts if they exist + if hasattr(self, 'forward_rule_counts'): + self.forward_rule_counts.clear() + if hasattr(self, 'reverse_rule_counts'): + self.reverse_rule_counts.clear() + + # Clear conflict tracking if it exists + if hasattr(self, 'rule_conflicts'): + self.rule_conflicts.clear() + if hasattr(self, 'rule_conflict_counts'): + self.rule_conflict_counts.clear() + + # Clear rules with IDs if they exist + if hasattr(self, 'forward_rules_with_ids'): + self.forward_rules_with_ids.clear() + if hasattr(self, 'reverse_rules_with_ids'): + self.reverse_rules_with_ids.clear() + + # Clear conflict contexts if they exist + if 
hasattr(self, 'conflict_contexts'):
+            self.conflict_contexts.clear()
+
+    def compute_loss(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher,
+                     loss_type='chunked_ce', exact_token_match_only=False, temperature=1.0,
+                     loss_on_non_zero_only=False, debug_verbose=False, kd_topk: int = 0, vocab_topk: int = 8192,
+                     reverse_kl: bool = False, project_teacher_logits_to_student: bool = False,
+                     log_softmax: str = "together", token_weights=None, gold_loss: bool = False,
+                     xtoken_loss: bool = False) -> tuple:
+        '''
+        Compute the loss between two sequences of tokens.
+
+        Args:
+            aligned_pairs: Aligned token pairs with alignment mask (7th element indicates correctness)
+            student_logits: Student model logits
+            teacher_logits: Teacher model logits
+            input_ids_student: Student input token IDs
+            input_ids_teacher: Teacher input token IDs
+            loss_type: 'chunked_ce' -> compute loss on chunks of tokens, from tokenkit
+                       'KL' -> compute KL divergence between teacher and student logits
+                       'cross_entropy' -> compute cross-entropy loss
+            exact_token_match_only: If True, only use 1-1 token mappings that are correct according to the mask
+                                    If False, use all alignments that are correct according to the mask
+            temperature: Temperature for softening probability distributions (used in KL loss)
+            loss_on_non_zero_only: If True, computes KL divergence only on non-zero vocabulary subset
+                                   (only used in KL loss type)
+            project_teacher_logits_to_student: If True, project teacher logits to student space (reverse projection)
+
+        Returns:
+            Tuple of (loss, topk_accuracy)
+        '''
+        # Make sure aligned_pairs are present
+        if not aligned_pairs:
+            raise ValueError("No aligned pairs found. 
Please align the sequences first.") + + topk_accuracy = 0.0 + + #create list of tokenids with correct alignments using the alignment mask + #for exact_token_match_only, add constraint that it should be 1-1 mapping + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + if exact_token_match_only: + # Use mask + 1-1 mapping constraint + tokenids_with_exact_match = [ + [el for el in batch_pairs if len(el) > 6 and el[6] and len(el[0]) == 1 and len(el[1]) == 1] + for batch_pairs in aligned_pairs + ] + else: + # Use only the alignment mask (with fallback to old behavior if mask not available) + tokenids_with_exact_match = [ + [el for el in batch_pairs if len(el) > 6 and el[6]] if batch_pairs and len(batch_pairs[0]) > 6 + else [el for el in batch_pairs if el[0] == el[1]] # fallback to old behavior + for batch_pairs in aligned_pairs + ] + else: + if exact_token_match_only: + # Use mask + 1-1 mapping constraint + tokenids_with_exact_match = [[ + (s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest) + for s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest in aligned_pairs + if len(rest) > 0 and rest[0] and len(s1_elems) == 1 and len(s2_elems) == 1 + ]] + else: + # Use only the alignment mask (with fallback to old behavior if mask not available) + if aligned_pairs and len(aligned_pairs[0]) > 6: + tokenids_with_exact_match = [[ + (s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest) + for s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest in aligned_pairs + if len(rest) > 0 and rest[0] + ]] + else: + # Fallback to old behavior for backward compatibility + tokenids_with_exact_match = [[ + (s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end) + for s1_elems, s2_elems, s1_start, s1_end, s2_start, s2_end, *rest in aligned_pairs + if s1_elems == s2_elems + ]] + + #compute the loss + if loss_type == 'chunked_ce': + #from tokenkit + loss = self.compute_ce_loss(aligned_pairs, student_logits, 
teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match, exact_token_match_only)
+        elif loss_type == 'cross_entropy':
+            # Considering only correct alignments based on the mask;
+            # iterate over the batch dimension
+            if exact_token_match_only:
+                losses = []
+                for batch_idx in range(student_logits.shape[0]):
+                    for alignment_pair in tokenids_with_exact_match[batch_idx]:
+                        # Extract components from alignment pair
+                        _, _, start1, end1, _, _ = alignment_pair[:6]
+                        if (start1 == -1 and end1 == -1) or (start1 >= input_ids_student.shape[1]):
+                            continue  # skip out-of-bounds indices
+                        # Clamp so the shifted targets stay in range: logits at
+                        # position t predict the token at t+1
+                        end1 = min(end1, input_ids_student.shape[1] - 1)
+                        if end1 <= start1:
+                            continue
+                        logits = student_logits[batch_idx, start1:end1, :]
+                        targets = input_ids_student[batch_idx, start1 + 1:end1 + 1]  # shifted by one
+                        losses.append(torch.nn.functional.cross_entropy(logits.view(-1, student_logits.size(-1)), targets.view(-1)))
+                if losses:
+                    loss = torch.stack(losses).mean()
+                else:
+                    loss = torch.tensor(0.0, device=student_logits.device, requires_grad=True)
+            else:
+                loss = torch.nn.functional.cross_entropy(student_logits[:, :-1].reshape(-1, student_logits.size(-1)), input_ids_student[:, 1:].reshape(-1))
+
+        elif loss_type == 'KL':
+            # Use the ultra-fast sparse path when per-position top-k is disabled
+            # (vocab_topk <= -1) and a sparse transformation matrix is available
+            if vocab_topk <= -1 and hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None:
+                loss, topk_accuracy = self.compute_KL_loss_ultra_fast(
+                    aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match,
+                    exact_token_match_only, temperature=temperature, vocab_topk=vocab_topk, use_mixed_precision=True, reverse_kl=reverse_kl
+                )
+            else:
+                loss, topk_accuracy = self.compute_KL_loss_optimized(
+                    aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match,
+                    exact_token_match_only, temperature=temperature, loss_on_non_zero_only=loss_on_non_zero_only, debug_verbose=debug_verbose, kd_topk=kd_topk, vocab_topk=vocab_topk, reverse_kl=reverse_kl, 
project_teacher_logits_to_student=project_teacher_logits_to_student, log_softmax=log_softmax, token_weights=token_weights, gold_loss=gold_loss, xtoken_loss=xtoken_loss, + ) + else: + raise ValueError(f"Loss type {loss_type} not supported") + + return loss, topk_accuracy + + def compute_feature_mse_loss(self, aligned_pairs, student_features, teacher_features, exact_token_match_only=True): + """ + Compute MSE loss between student and teacher features for exactly matching tokens. + + Args: + aligned_pairs: List of alignment information for each batch + student_features: Tensor of shape (batch_size, seq_len, hidden_dim) - student hidden states + teacher_features: Tensor of shape (batch_size, seq_len, hidden_dim) - teacher hidden states + exact_token_match_only: If True, only compute loss for tokens that exactly match + + Returns: + MSE loss tensor + """ + if not aligned_pairs: + raise ValueError("No aligned pairs found. Please align the sequences first.") + + # Create list of tokenids with exact token text match between teacher and student + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + tokenids_with_exact_match = [[el for el in batch_pairs if el[0] == el[1]] for batch_pairs in aligned_pairs] + else: + tokenids_with_exact_match = [s1_elems for s1_elems, s2_elems, *_ in aligned_pairs if s1_elems == s2_elems] + + if exact_token_match_only: + # Collect features for exactly matching tokens + student_features_matched = [] + teacher_features_matched = [] + + for batch_idx in range(student_features.shape[0]): + for _, _, start1, end1, start2, end2, *_ in tokenids_with_exact_match[batch_idx]: + # Skip invalid indices + if (start1 == -1 and end1 == -1) or (start1 >= student_features.shape[1]): + continue + if (start2 == -1 and end2 == -1) or (start2 >= teacher_features.shape[1]): + continue + + # Extract features for the matching token spans + student_span_features = student_features[batch_idx, start1:end1, :] # Shape: (span_len, 
hidden_dim) + teacher_span_features = teacher_features[batch_idx, start2:end2, :] # Shape: (span_len, hidden_dim) + + # Handle different span lengths by taking mean pooling or truncating + if student_span_features.shape[0] != teacher_span_features.shape[0]: + # Use mean pooling to handle different span lengths + student_span_mean = student_span_features.mean(dim=0, keepdim=True) # Shape: (1, hidden_dim) + teacher_span_mean = teacher_span_features.mean(dim=0, keepdim=True) # Shape: (1, hidden_dim) + student_features_matched.append(student_span_mean) + teacher_features_matched.append(teacher_span_mean) + else: + # Same span length, add all tokens + student_features_matched.append(student_span_features) + teacher_features_matched.append(teacher_span_features) + + # If no matching tokens found, return zero loss with gradient + if not student_features_matched: + return torch.tensor(0.0, device=student_features.device, dtype=student_features.dtype, requires_grad=True) + + # Concatenate all matched features + student_features_matched = torch.cat(student_features_matched, dim=0) # Shape: (total_matched_tokens, hidden_dim) + teacher_features_matched = torch.cat(teacher_features_matched, dim=0) # Shape: (total_matched_tokens, hidden_dim) + + # Debug: Check for NaN before MSE computation + student_has_nan = torch.isnan(student_features_matched).any() + teacher_has_nan = torch.isnan(teacher_features_matched).any() + + if student_has_nan or teacher_has_nan: + print(f"DEBUG: NaN detected before MSE - student: {student_has_nan.item()}, teacher: {teacher_has_nan.item()}") + print(f"Student matched shape: {student_features_matched.shape}, Teacher matched shape: {teacher_features_matched.shape}") + if student_has_nan: + print(f"Student features stats: min={student_features_matched.min().item()}, max={student_features_matched.max().item()}") + if teacher_has_nan: + print(f"Teacher features stats: min={teacher_features_matched.min().item()}, 
max={teacher_features_matched.max().item()}") + # Return zero loss to avoid NaN propagation + return torch.tensor(0.0, device=student_features.device, dtype=student_features.dtype, requires_grad=True) + + # Check for extreme values that might cause NaN + if torch.isinf(student_features_matched).any() or torch.isinf(teacher_features_matched).any(): + print("DEBUG: Infinite values detected in matched features") + return torch.tensor(0.0, device=student_features.device, dtype=student_features.dtype, requires_grad=True) + + # Compute MSE loss + mse_loss = torch.nn.functional.mse_loss(student_features_matched, teacher_features_matched) + + # Debug: Check if MSE computation resulted in NaN + if torch.isnan(mse_loss): + print(f"DEBUG: MSE loss is NaN! student_matched stats: min={student_features_matched.min().item():.6f}, max={student_features_matched.max().item():.6f}") + print(f"teacher_matched stats: min={teacher_features_matched.min().item():.6f}, max={teacher_features_matched.max().item():.6f}") + print(f"Difference stats: min={(student_features_matched - teacher_features_matched).min().item():.6f}, max={(student_features_matched - teacher_features_matched).max().item():.6f}") + return torch.tensor(0.0, device=student_features.device, dtype=student_features.dtype, requires_grad=True) + + else: + # Compute MSE loss over all positions (not recommended for cross-tokenizer alignment) + # This assumes student and teacher sequences have the same length + min_seq_len = min(student_features.shape[1], teacher_features.shape[1]) + student_features_truncated = student_features[:, :min_seq_len, :] + teacher_features_truncated = teacher_features[:, :min_seq_len, :] + mse_loss = torch.nn.functional.mse_loss(student_features_truncated, teacher_features_truncated) + + return mse_loss + + def transform_logits(self, input_logits): + """ + Project student logits to teacher-vocabulary space using the binary sparse + matrix `P` (shape [student_vocab, teacher_vocab]). 
+        The projection keeps probability semantics: teacher tokens with *no*
+        mapping receive (near-)zero probability from the sparse matmul, so after
+        the log below they sit near log(1e-8) and are effectively suppressed
+        without materializing any -inf entries.
+        """
+        P = self.sparse_transformation_matrix  # binary CSR matrix already on GPU
+        if P is None:
+            return None
+
+        with torch.no_grad():
+            # 1. Sparse matmul in bf16/fp16 (no big fp32 tensors)
+            projected = TokenAligner.project_token_likelihoods_sparse(
+                input_logits.softmax(dim=-1), P, input_logits.device
+            )
+
+            # 2. Disabled variant: force columns with no mapping to -inf (probability 0).
+            #    The additive epsilon in the log below achieves the same suppression stably.
+            # with torch.no_grad():
+            #     # `column_has_data` is 1 for columns that receive at least one copy-over
+            #     column_has_data = (P.sum(dim=0) > 0).to(projected.dtype)  # shape (teacher_vocab,)
+            #     minus_inf = -torch.finfo(projected.dtype).max
+            #     projected = projected * column_has_data + minus_inf * (1.0 - column_has_data)
+            projected = torch.log(projected + 1e-8)
+
+        return projected
+
+    def transform_learned_matrix_instance(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        """
+        Instance-method wrapper that reads `enable_scale_trick` from the instance.
+        """
+        scale_trick_enabled = self.enable_scale_trick if self.enable_scale_trick is not None else False
+        return TokenAligner.transform_learned_matrix(x, dim, enable_scale_trick=scale_trick_enabled)
+
+    @staticmethod
+    def transform_learned_matrix(x: torch.Tensor, dim: int = -1, enable_scale_trick=None) -> torch.Tensor:
+        """
+        Normalize the learned projection scores along `dim` with a softmax
+        (optionally applying the scale trick). A "quiet" softmax variant,
+        exp(x) / (1 + sum(exp(x))), is kept below as a disabled alternative.
+
+        Args:
+            x: Input tensor.
+            dim: Dimension to normalize over (default: -1).
+
+        Returns:
+            Tensor of the same shape with normalized scores.
+        """
+        # Disabled "quiet" softmax (softmax-one) alternative; an option is to use
+        # a constant c instead of 1 in the denominator:
+        #     exp_x = torch.exp(x)
+        #     denom = 1 + torch.sum(exp_x, dim=dim, keepdim=True)
+        #     return exp_x / denom
+        scale_trick_enabled = enable_scale_trick if enable_scale_trick is not None else False
+        if scale_trick_enabled:
+            # Trick: treat the last column as a learned 0..1 multiplier; the masking
+            # and sigmoid variants below were tried and are kept disabled.
+            scores = torch.nn.functional.softmax(x, dim=dim)
+            # Create a mask to zero out the last column while preserving gradients:
+            # mask = torch.ones_like(scores)
+            # mask[:, -1] = 0.0
+            # scores = scores * mask
+            # Alternative approach using sigmoid:
+            # scores = scores * torch.sigmoid(x[:, -1].unsqueeze(-1))
+            return scores
+        # Normal softmax
+        return torch.nn.functional.softmax(x, dim=dim)
+
+    def compute_KL_loss(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=False, temperature=0.1, loss_on_non_zero_only=False, debug_verbose=False, kd_topk: int = 0):
+        """
+        Computes a KL divergence loss between student and teacher. The student
+        distribution is first projected into the teacher's vocabulary space
+        (student->teacher projection) and the divergence is then taken against
+        the teacher distribution.
+
+        Args:
+            aligned_pairs: List of alignment information.
+            student_logits: Logits from the student model.
+            teacher_logits: Logits from the teacher model.
+            input_ids_student: Input token IDs for the student.
+            input_ids_teacher: Input token IDs for the teacher.
+            tokenids_with_exact_match: Pre-filtered list of alignment pairs.
+            exact_token_match_only: If True, computes loss only on 1-to-1 matching tokens.
+            temperature: Temperature for softening probability distributions.
+            loss_on_non_zero_only: If True, computes KL divergence only on the non-zero vocabulary subset.
+            debug_verbose: If True, prints debug statistics.
+            kd_topk: If > 0, restricts the divergence to the teacher's top-k tokens (renormalized).
+
+        Returns:
+            tuple: (KL divergence loss tensor scaled by temperature**2, top-1 accuracy float)
+ """ + + # Always use student->teacher projection: KL(student_projected || teacher) + # Project student logits to teacher's vocabulary space + # student_probs = torch.nn.functional.softmax(student_logits / temperature, dim=-1) + if tokenids_with_exact_match is None: + tokenids_with_exact_match = [ + [el for el in batch_pairs if len(el) > 6 and el[6] and len(el[0]) == 1 and len(el[1]) == 1] + for batch_pairs in aligned_pairs + ] + + + #july 26th, lets project logits not probabilities as before + + # student_logits = student_logits / temperature + student_probs = torch.nn.functional.softmax(student_logits / temperature, dim=-1) + + # Detect which format is loaded and use appropriate projection + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Use sparse matrix projection (student→teacher) + student_logits_projected = self.project_token_likelihoods_instance( + student_probs, + None, None, None, # Not used for sparse format + teacher_logits.device, + use_sparse_format=True, + sparse_matrix=self.sparse_transformation_matrix + ) + elif hasattr(self, 'likelihood_projection_indices') and self.likelihood_projection_indices is not None: + # Use dense projection (student→teacher) + student_logits_projected = self.project_token_likelihoods_instance( + student_probs, + self.likelihood_projection_indices, + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) if self.learnable else self.likelihood_projection_matrix, + teacher_logits.shape[-1], + teacher_logits.device, + use_sparse_format=False, + # global_top_indices=self.global_top_indices + ) + else: + raise ValueError("No projection matrix loaded. 
Please call _load_logits_projection_map() first.")
+
+        # Get teacher log-probabilities (target distribution). The teacher logits
+        # already live in the projected (teacher) vocabulary, so no truncation or
+        # padding is needed to match vocabulary sizes.
+        if debug_verbose:
+            # Debug stats for teacher_logits (original)
+            print(f"DEBUG: teacher_logits (original) - shape: {teacher_logits.shape}")
+            print(f"DEBUG: teacher_logits (original) - min: {teacher_logits.min().item():.6f}, max: {teacher_logits.max().item():.6f}, mean: {teacher_logits.mean().item():.6f}")
+            print(f"DEBUG: teacher_logits (original) - has NaN: {torch.isnan(teacher_logits).any().item()}, has inf: {torch.isinf(teacher_logits).any().item()}")
+            # Debug stats for the projected student distribution
+            # Note: 'projected_probs' is defined below; at this point only the expected vocab size is known
+            print(f"DEBUG: projected student distribution (after projection) - expected vocab: {teacher_logits.shape[-1]}")
+
+        
teacher_logits_matched = teacher_logits + + teacher_log_probs = torch.nn.functional.log_softmax(teacher_logits_matched / temperature, dim=-1) + + # Debug stats for teacher_logits_matched and teacher_log_probs + if debug_verbose: + print(f"DEBUG: teacher_logits_matched - shape: {teacher_logits_matched.shape}") + print(f"DEBUG: teacher_logits_matched - min: {teacher_logits_matched.min().item():.6f}, max: {teacher_logits_matched.max().item():.6f}, mean: {teacher_logits_matched.mean().item():.6f}") + print(f"DEBUG: teacher_log_probs - shape: {teacher_log_probs.shape}") + print(f"DEBUG: teacher_log_probs - min: {teacher_log_probs.min().item():.6f}, max: {teacher_log_probs.max().item():.6f}, mean: {teacher_log_probs.mean().item():.6f}") + print(f"DEBUG: teacher_log_probs - has NaN: {torch.isnan(teacher_log_probs).any().item()}, has inf: {torch.isinf(teacher_log_probs).any().item()}") + + # Use student_probs_projected as P and teacher_log_probs as Q + # KL(P || Q) = KL(student_projected || teacher) + # projected_probs = torch.nn.functional.softmax(student_logits_projected, dim=-1) + projected_probs = student_logits_projected + target_log_probs = teacher_log_probs + + # Optional teacher top-k with renormalization (argument-driven) + # Only enable top-k in the exact-match path; the chunk path expects full vocab shapes + if kd_topk and exact_token_match_only and not loss_on_non_zero_only: + k = min(int(kd_topk), teacher_logits_matched.shape[-1]) + if k > 0 and k < teacher_logits_matched.shape[-1]: + # Teacher: top-k logits and renormalized log-probs over k + topk = torch.topk(teacher_logits_matched, k=k, dim=-1) + topk_indices = topk.indices + topk_logits = topk.values / temperature + target_log_probs = torch.nn.functional.log_softmax(topk_logits, dim=-1) + + # Student: gather projected probs at teacher top-k indices and renormalize over k + gathered = torch.gather(projected_probs, dim=-1, index=topk_indices) + denom = gathered.sum(dim=-1, keepdim=True).clamp_min(1e-10) + 
projected_probs = gathered / denom + + # Debug stats for projected_probs and target_log_probs + if debug_verbose: + print(f"DEBUG: projected_probs - min: {projected_probs.min().item():.6f}, max: {projected_probs.max().item():.6f}, mean: {projected_probs.mean().item():.6f}") + print(f"DEBUG: projected_probs - has NaN: {torch.isnan(projected_probs).any().item()}, has inf: {torch.isinf(projected_probs).any().item()}") + print(f"DEBUG: target_log_probs - min: {target_log_probs.min().item():.6f}, max: {target_log_probs.max().item():.6f}, mean: {target_log_probs.mean().item():.6f}") + print(f"DEBUG: target_log_probs - has NaN: {torch.isnan(target_log_probs).any().item()}, has inf: {torch.isinf(target_log_probs).any().item()}") + + if loss_on_non_zero_only: + if not hasattr(self, 'sparse_transformation_matrix') or self.sparse_transformation_matrix is None: + raise ValueError("loss_on_non_zero_only=True requires a sparse transformation matrix to be loaded.") + + # Cache the mask for efficiency + # For student→teacher projection, non-zero indices are in the teacher vocabulary (columns) + if not hasattr(self, '_non_zero_teacher_vocab_mask'): + with torch.no_grad(): + # Get the unique column indices from the sparse matrix, which correspond to the teacher vocabulary + non_zero_indices = self.sparse_transformation_matrix.coalesce().indices()[1].unique() + # Create a mask for the full vocabulary + mask = torch.zeros(teacher_logits_matched.shape[-1], dtype=torch.bool, device=teacher_logits_matched.device) + mask[non_zero_indices] = True + self._non_zero_teacher_vocab_mask = mask + vocab_mask = self._non_zero_teacher_vocab_mask + + # Apply mask to both projected probabilities and target logits + # Zero out probabilities for tokens not in the transformation matrix + projected_probs = projected_probs * vocab_mask.unsqueeze(0).unsqueeze(0) + # renormalize projected probabilities + projected_probs = projected_probs / projected_probs.sum(-1, keepdim=True) # didnt check it before hand + 
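Two conventions in this file are easy to invert: a distribution restricted to a vocabulary subset must be renormalized before use, and `torch.nn.functional.kl_div(input, target)` expects `input` in log-space and computes KL(target || input), i.e. the teacher (target) weights the log-ratio. A minimal self-contained sanity check of both, in plain PyTorch with hypothetical sizes (not repo code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 8
student_probs = torch.softmax(torch.randn(3, vocab), dim=-1)
teacher_logits = torch.randn(3, vocab)

# Restrict both sides to a vocabulary subset, as with the sparse-matrix mask.
vocab_mask = torch.zeros(vocab, dtype=torch.bool)
vocab_mask[[0, 2, 5, 7]] = True

# Zero out unmapped tokens and renormalize so rows sum to 1 again.
student_masked = student_probs * vocab_mask
student_masked = student_masked / student_masked.sum(-1, keepdim=True).clamp_min(1e-10)
teacher_log_probs = torch.log_softmax(teacher_logits.masked_fill(~vocab_mask, -1e9), dim=-1)

# F.kl_div(input, target) with log-space input computes KL(target || input).
student_log = torch.log(student_masked + 1e-10)
kl_builtin = F.kl_div(student_log, teacher_log_probs, reduction="batchmean", log_target=True)

# Manual KL(teacher || student) for comparison.
teacher_probs = teacher_log_probs.exp()
kl_manual = (teacher_probs * (teacher_log_probs - student_log)).sum(-1).mean()

assert torch.allclose(kl_builtin, kl_manual, atol=1e-5)
assert torch.allclose(student_masked.sum(-1), torch.ones(3), atol=1e-5)
```

Swapping the two arguments (or forgetting `log_target`) silently optimizes the reverse divergence, which is why the surrounding comments spell out the direction.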
+ # Zero out target logits for tokens not in the transformation matrix (before softmax) + masked_teacher_logits = teacher_logits_matched * vocab_mask.unsqueeze(0).unsqueeze(0) + # Set masked positions to very negative values so they don't contribute to softmax + masked_teacher_logits = masked_teacher_logits + (~vocab_mask).unsqueeze(0).unsqueeze(0) * (-1e9) + target_log_probs = torch.nn.functional.log_softmax(masked_teacher_logits / temperature, dim=-1) + + if exact_token_match_only: + # Create boolean masks to select only the distributions for exactly matched tokens + projected_mask = torch.zeros(projected_probs.shape[:2], dtype=torch.bool, device=student_logits.device) + target_mask = torch.zeros(target_log_probs.shape[:2], dtype=torch.bool, device=student_logits.device) + + for example_idx in range(student_logits.shape[0]): + for alignment_pair in aligned_pairs[example_idx]: + # Extract components from alignment pair + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if (start1 == -1 and end1 == -1) or (start1 >= input_ids_student.shape[1]): + continue + if (start2 == -1 and end2 == -1) or (start2 >= input_ids_teacher.shape[1]): + continue + if start1 == 0 or start2 == 0: + continue + if (end1 - start1 != end2 - start2) and (end1-start1 != 1): + continue + # print(f"s1text: {s1text}, s2text: {s2text}, start1: {start1}, end1: {end1}, start2: {start2}, end2: {end2}") + # For student→teacher projection: start1 is student (projected), start2 is teacher (target) + projected_mask[example_idx, start1-1] = True # student positions + target_mask[example_idx, start2-1] = True # teacher positions + + # Apply masks to get distributions for aligned tokens + # Select only positions where mask is True + projected_probs_masked = projected_probs[projected_mask] # Shape: (num_true_positions, vocab_size) + target_log_probs_masked = target_log_probs[target_mask] # Shape: (num_true_positions, vocab_size) + + projected_log_probs_masked = 
torch.log(projected_probs_masked + 1e-10)
+
+            # If no tokens are aligned, loss is 0, but we need a tensor with grad_fn
+            # (returned as a (loss, accuracy) tuple, matching the normal return path)
+            if projected_probs_masked.numel() == 0:
+                return torch.tensor(0.0, device=student_logits.device, requires_grad=True), 0.0
+
+            # Compute KL divergence on the masked distributions. Direction note:
+            # F.kl_div(input, target) computes KL(target || input), so with the
+            # teacher as target this is KL(teacher || projected_student).
+            loss_kl = torch.nn.functional.kl_div(projected_log_probs_masked, target_log_probs_masked, reduction="batchmean", log_target=True)
+
+            if 1:
+                # For debugging: compute top-1 agreement between the projected
+                # student and the teacher on the exactly matched positions
+                with torch.no_grad():
+                    if projected_probs_masked.numel() > 0:
+                        # Use masked versions for exact token matching
+                        student_top1_masked = torch.topk(projected_probs_masked, k=min(1, projected_probs_masked.shape[-1]), dim=-1).indices
+                        teacher_probs_masked = torch.exp(target_log_probs_masked)
+                        teacher_top1_masked = torch.topk(teacher_probs_masked, k=min(1, teacher_probs_masked.shape[-1]), dim=-1).indices
+
+                        # Calculate overlap between top-1 predictions
+                        matches = 0
+                        total = 0
+                        for i in range(student_top1_masked.shape[0]):
+                            student_set = set(student_top1_masked[i].cpu().numpy())
+                            teacher_set = set(teacher_top1_masked[i].cpu().numpy())
+                            if len(student_set.intersection(teacher_set)) > 0:
+                                matches += 1
+                            total += 1
+
+                        top1_accuracy = matches / total if total > 0 else 0.0
+
+        else:
+            # Chunk-based alignment with proper averaging to handle many-to-many token mappings
+            
max_length_projected = projected_probs.shape[1] + max_length_target = target_log_probs.shape[1] + max_n_chunks = min(max_length_projected, max_length_target) + n_examples = student_logits.shape[0] + + projected_tokens_to_chunks = torch.zeros((n_examples, max_length_projected, max_n_chunks), dtype=torch.bool).to(student_logits.device) + target_tokens_to_chunks = torch.zeros((n_examples, max_length_target, max_n_chunks), dtype=torch.bool).to(student_logits.device) + + # Build alignment masks + for example_idx in range(n_examples): + chunk_idx = 0 + for alignment_pair in aligned_pairs[example_idx]: + # Extract components from alignment pair + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if start1 != -1 and start2 != -1 and chunk_idx < max_n_chunks: + # For student→teacher projection: start1 is student (projected), start2 is teacher (target) + projected_tokens_to_chunks[example_idx, start1:end1, chunk_idx] = 1 # student positions + target_tokens_to_chunks[example_idx, start2:end2, chunk_idx] = 1 # teacher positions + chunk_idx += 1 + + # Compute chunk-averaged distributions + # For student (projected probabilities): average probabilities within each chunk + projected_chunk_probs = torch.bmm( + projected_tokens_to_chunks.transpose(1, 2).to(projected_probs.dtype), # (batch, max_n_chunks, max_length_projected) + projected_probs # (batch, max_length_projected, vocab_size) + ) # Result: (batch, max_n_chunks, vocab_size) + + # Normalize by number of tokens in each chunk to get proper averages + chunk_sizes_projected = projected_tokens_to_chunks.sum(dim=1, keepdim=True).float() # (batch, 1, max_n_chunks) + chunk_sizes_projected = chunk_sizes_projected.transpose(1, 2) # (batch, max_n_chunks, 1) + projected_chunk_probs = projected_chunk_probs / (chunk_sizes_projected + 1e-10) # Avoid division by zero + + # Renormalize to ensure probabilities sum to 1 (handles numerical precision errors) + projected_chunk_probs = projected_chunk_probs / 
(projected_chunk_probs.sum(dim=-1, keepdim=True) + 1e-10) + + # Convert projected chunk probabilities to log probabilities (will recompute after optional top-k) + projected_chunk_log_probs = torch.log(projected_chunk_probs + 1e-10) + + # Alternative: Geometric mean instead of arithmetic mean (uncomment to try) + # This computes (P1 * P2 * ... * Pn)^(1/n) for each chunk + # projected_chunk_probs_geom = torch.ones_like(projected_chunk_probs) + # for example_idx in range(n_examples): + # for chunk_idx in range(max_n_chunks): + # mask = projected_tokens_to_chunks[example_idx, :, chunk_idx] + # if mask.any(): + # chunk_tokens = projected_probs[example_idx][mask] # (num_tokens_in_chunk, vocab_size) + # chunk_product = torch.prod(chunk_tokens + 1e-10, dim=0) # Product across tokens + # chunk_geom_mean = torch.pow(chunk_product, 1.0 / mask.sum().float()) + # projected_chunk_probs_geom[example_idx, chunk_idx] = chunk_geom_mean + + # For teacher: convert logits to probabilities first, then average probabilities (consistent with student) + teacher_probs = torch.softmax(teacher_logits_matched / temperature, dim=-1) # Convert to probabilities first + target_chunk_probs = torch.bmm( + target_tokens_to_chunks.transpose(1, 2).to(teacher_probs.dtype), # (batch, max_n_chunks, max_length_target) + teacher_probs # (batch, max_length_target, vocab_size) + ) # Result: (batch, max_n_chunks, vocab_size) + + # Normalize by number of tokens in each chunk to get proper averages + chunk_sizes_target = target_tokens_to_chunks.sum(dim=1, keepdim=True).float() # (batch, 1, max_n_chunks) + chunk_sizes_target = chunk_sizes_target.transpose(1, 2) # (batch, max_n_chunks, 1) + target_chunk_probs = target_chunk_probs / (chunk_sizes_target + 1e-10) # Avoid division by zero + + # Renormalize to ensure probabilities sum to 1 (handles numerical precision errors) + target_chunk_probs = target_chunk_probs / (target_chunk_probs.sum(dim=-1, keepdim=True) + 1e-10) + + # Optional top-k over chunks 
(argument-driven)
+            if kd_topk and not loss_on_non_zero_only:
+                k = min(int(kd_topk), target_chunk_probs.shape[-1])
+                if k > 0 and k < target_chunk_probs.shape[-1]:
+                    topk = torch.topk(target_chunk_probs, k=k, dim=-1)
+                    indices_k = topk.indices
+                    target_probs_k = topk.values
+                    projected_probs_k = torch.gather(projected_chunk_probs, dim=-1, index=indices_k)
+                    # Renormalize over the k kept tokens
+                    t_denom = target_probs_k.sum(dim=-1, keepdim=True).clamp_min(1e-10)
+                    s_denom = projected_probs_k.sum(dim=-1, keepdim=True).clamp_min(1e-10)
+                    target_chunk_probs = target_probs_k / t_denom
+                    projected_chunk_probs = projected_probs_k / s_denom
+                    # Recompute projected log-probs after slicing to keep shapes aligned
+                    projected_chunk_log_probs = torch.log(projected_chunk_probs + 1e-10)
+
+            # Convert target chunk probabilities to log probabilities
+            target_chunk_log_probs = torch.log(target_chunk_probs + 1e-10)
+
+            # Create mask for valid chunks (chunks that have tokens from both sides)
+            chunk_mask = (chunk_sizes_projected.squeeze(-1) > 0) & (chunk_sizes_target.squeeze(-1) > 0)
+
+            # Compute KL divergence per chunk. Direction note: F.kl_div(input, target)
+            # computes KL(target || input), so this is KL(teacher_chunk || projected_chunk).
+            loss_kl = torch.nn.functional.kl_div(
+                projected_chunk_log_probs,
+                target_chunk_log_probs,
+                reduction="none",
+                log_target=True
+            )
+
+            # Apply chunk mask and compute weighted average
+            if chunk_mask.sum() > 0:
+                loss_kl_weighted = (loss_kl * chunk_mask.unsqueeze(-1)).sum() / chunk_mask.sum()
+                loss_kl = loss_kl_weighted
+            else:
+                loss_kl = torch.tensor(0.0, device=student_logits.device, requires_grad=True)
+
+            if 1:
+                # Compute top-1 accuracy for chunk-based alignment
+                with torch.no_grad():
+                    if chunk_mask.sum() > 0:
+                        # Get top-1 predictions from chunk-averaged distributions
+                        student_top1_indices = torch.argmax(projected_chunk_probs, dim=-1)  # (batch, max_n_chunks)
+                        teacher_top1_indices = torch.argmax(target_chunk_probs, dim=-1)  # (batch, max_n_chunks)
+
+                        # Count matches only for valid chunks
+                        matches = ((student_top1_indices 
== teacher_top1_indices) & chunk_mask).sum().item() + total = chunk_mask.sum().item() + + top1_accuracy = matches / total if total > 0 else 0.0 + else: + top1_accuracy = 0.0 + + # Scale loss by temperature squared + return loss_kl * (temperature**2), top1_accuracy + + def compute_projected_logits_KL_loss(self, aligned_pairs, projected_student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=True, temperature=1.0, rewrite_with_sparse_projection=False, kd_topk: int = 0): + """ + Computes KL divergence loss for student logits that have already been projected into the teacher's vocabulary space. + This function does NOT use the internal transformation_matrix by default. It assumes the projection is done externally. + It relies on alignment pairs to correctly match tokens between the student and teacher sequences, which may have different lengths. + + Args: + aligned_pairs: List of alignment information for each batch. + projected_student_logits: Student logits projected to the teacher's vocabulary space. + Shape: (batch_size, student_seq_len, teacher_vocab_size). + teacher_logits: Original teacher model logits. + Shape: (batch_size, teacher_seq_len, teacher_vocab_size). + input_ids_student: Student input token IDs. + input_ids_teacher: Teacher input token IDs. + tokenids_with_exact_match: Pre-filtered list of alignment pairs. If None, it will be computed. + exact_token_match_only: If True, computes loss only on 1-to-1 token mappings that are textually identical. + If False, uses chunk-based alignment similar to compute_KL_loss. + temperature: Temperature for softening probability distributions. + rewrite_with_sparse_projection: If True, rewrites elements in projected_student_logits using the + sparse transformation matrix for vocabulary positions that have + non-zero entries in the matrix. Requires sparse_transformation_matrix + to be loaded. 
+ + Returns: + tuple: (Computed KL divergence loss tensor, top-1 accuracy float) + """ + if tokenids_with_exact_match is None: + if isinstance(aligned_pairs, list) and aligned_pairs and isinstance(aligned_pairs[0], list): + if exact_token_match_only: + # Use mask + 1-1 mapping constraint + tokenids_with_exact_match = [ + [el for el in batch_pairs if len(el) > 6 and el[6] and len(el[0]) == 1 and len(el[1]) == 1] + for batch_pairs in aligned_pairs + ] + else: + # Use only the alignment mask (with fallback to old behavior if mask not available) + tokenids_with_exact_match = [ + [el for el in batch_pairs if len(el) > 6 and el[6]] if batch_pairs and len(batch_pairs[0]) > 6 + else [el for el in batch_pairs if el[0] == el[1]] # fallback to old behavior + for batch_pairs in aligned_pairs + ] + else: + raise ValueError("aligned_pairs must be a list of lists (batched input).") + + # Optionally rewrite projected logits using sparse transformation matrix + if rewrite_with_sparse_projection: + if not hasattr(self, 'sparse_transformation_matrix') or self.sparse_transformation_matrix is None: + raise ValueError("rewrite_with_sparse_projection=True requires a sparse transformation matrix to be loaded.") + + # Since the sparse matrix contains only 0s and 1s, we can work directly with logits + # and avoid expensive softmax/log conversions + + # Get the sparse matrix indices for direct mapping + sparse_indices = self.sparse_transformation_matrix.coalesce().indices() # Shape: [2, num_nonzero] + student_vocab_indices = sparse_indices[0] # Student vocabulary indices (rows) + teacher_vocab_indices = sparse_indices[1] # Teacher vocabulary indices (columns) + + # Clone to avoid in-place modification + projected_student_logits = projected_student_logits.clone() + + # For binary sparse matrices (0s and 1s), directly copy teacher logits to corresponding positions + # This overwrites the projected logits with teacher logits for vocabulary positions that have + # non-zero entries in the 
transformation matrix + projected_student_logits[:, :, teacher_vocab_indices] = teacher_logits[:, :, teacher_vocab_indices] + + # Get probabilities from logits (after potential rewriting) + projected_student_probs = torch.nn.functional.softmax(projected_student_logits / temperature, dim=-1) + teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) + projected_student_log_probs = torch.nn.functional.log_softmax(projected_student_logits / temperature, dim=-1) + + if exact_token_match_only: + # Create boolean masks to select distributions for exactly matched tokens + student_mask = torch.zeros(projected_student_logits.shape[:2], dtype=torch.bool, device=projected_student_logits.device) + teacher_mask = torch.zeros(teacher_logits.shape[:2], dtype=torch.bool, device=teacher_logits.device) + + for batch_idx in range(projected_student_logits.shape[0]): + if batch_idx >= len(tokenids_with_exact_match): + continue + + for _, _, start1, end1, start2, end2, *_ in tokenids_with_exact_match[batch_idx]: + # Skip invalid indices or non 1-to-1 matches + if start1 == -1 or start2 == -1 or (end1 - start1 != 1) or (end2 - start2 != 1): + continue + + if start1 == 0 or start2 == 0: + continue + + # Ensure indices are within bounds (accounting for shift) + if start1-1 >= projected_student_logits.shape[1] or start2-1 >= teacher_logits.shape[1]: + continue + if start1-1 < 0 or start2-1 < 0: + continue + #we will shift here - use logits at position t-1 to predict token at position t + student_mask[batch_idx, start1-1] = True + teacher_mask[batch_idx, start2-1] = True + + # Apply masks to get distributions for aligned tokens + student_log_probs_masked = projected_student_log_probs[student_mask] + teacher_probs_masked = teacher_probs[teacher_mask] + + # If no tokens are aligned, loss is 0 + if student_log_probs_masked.numel() == 0: + return torch.tensor(0.0, device=projected_student_logits.device, requires_grad=True), 0.0 + + # Ensure the number of matched tokens 
is consistent + if student_log_probs_masked.shape[0] != teacher_probs_masked.shape[0]: + # This case can indicate a bug in alignment or masking logic. + # It's safer to return 0 loss than to proceed with mismatched tensors. + return torch.tensor(0.0, device=projected_student_logits.device, requires_grad=True), 0.0 + + # Compute KL divergence on the masked distributions: KL(teacher || student) + loss_kl = torch.nn.functional.kl_div(student_log_probs_masked, teacher_probs_masked, reduction="batchmean", log_target=False) + + # Compute top-1 accuracy for exact token matching (projected logits) + with torch.no_grad(): + if student_log_probs_masked.numel() > 0: + # Convert log probabilities to probabilities for masked student predictions + student_probs_masked = torch.exp(student_log_probs_masked) + + # Get top-1 predictions for both + student_top1_masked = torch.topk(student_probs_masked, k=min(1, student_probs_masked.shape[-1]), dim=-1).indices + teacher_top1_masked = torch.topk(teacher_probs_masked, k=min(1, teacher_probs_masked.shape[-1]), dim=-1).indices + + # Calculate overlap between top-1 predictions + matches = 0 + total = 0 + for i in range(student_top1_masked.shape[0]): + student_set = set(student_top1_masked[i].cpu().numpy()) + teacher_set = set(teacher_top1_masked[i].cpu().numpy()) + if len(student_set.intersection(teacher_set)) > 0: + matches += 1 + total += 1 + + top1_accuracy = matches / total if total > 0 else 0.0 + # print(f"Top-1 accuracy (projected exact match): {top1_accuracy:.4f} ({matches}/{total})") + else: + top1_accuracy = 0.0 + # print("Top-1 accuracy (projected exact match): 0.0000 (0/0)") + + else: + # Chunk-based alignment similar to compute_KL_loss + # print("chunk-based alignment") + max_length_teacher = teacher_logits.shape[1] + max_length_student = projected_student_logits.shape[1] + max_n_chunks = min(max_length_teacher, max_length_student) + n_examples = projected_student_logits.shape[0] #batch size + + teacher_tokens_to_chunks = 
torch.zeros((n_examples, max_length_teacher, max_n_chunks), dtype=torch.bool).to(projected_student_logits.device)
+            student_tokens_to_chunks = torch.zeros((n_examples, max_length_student, max_n_chunks), dtype=torch.bool).to(projected_student_logits.device)
+
+            # Build one chunk per valid alignment pair; each chunk marks the token
+            # positions it covers on the student and teacher side respectively
+            for example_idx in range(n_examples):
+                chunk_idx = 0
+                for alignment_pair in aligned_pairs[example_idx]:
+                    s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6]
+                    if start1 != -1 and start2 != -1:
+                        student_tokens_to_chunks[example_idx, start1:end1, chunk_idx] = 1
+                        teacher_tokens_to_chunks[example_idx, start2:end2, chunk_idx] = 1
+                        chunk_idx += 1
+
+            # A chunk is valid only if it covers at least one token on both sides
+            chunk_mask = (teacher_tokens_to_chunks.sum(-2) > 0) & (student_tokens_to_chunks.sum(-2) > 0)
+
+            if False:  # disabled probability-space variant, kept for reference
+                teacher_chunk_probs = torch.bmm(teacher_tokens_to_chunks.transpose(1, 2).to(teacher_probs.dtype), teacher_probs)
+                student_chunk_probs = torch.bmm(student_tokens_to_chunks.transpose(1, 2).to(projected_student_log_probs.dtype), projected_student_log_probs.exp())
+
+                # KL divergence over the summed per-chunk probabilities
+                loss_kl = torch.nn.functional.kl_div(torch.log(student_chunk_probs + 1e-10), torch.log(teacher_chunk_probs + 1e-10), reduction="none", log_target=True)
+            else:
+                # Aggregate in logit space, then normalize per chunk with log_softmax
+                teacher_logits_chunk = torch.bmm(teacher_tokens_to_chunks.transpose(1, 2).to(teacher_logits.dtype), teacher_logits)
+                student_logits_chunk = torch.bmm(student_tokens_to_chunks.transpose(1, 2).to(projected_student_logits.dtype), projected_student_logits)
+                student_log_probs_chunk = torch.nn.functional.log_softmax(student_logits_chunk, dim=-1)
+                teacher_log_probs_chunk = torch.nn.functional.log_softmax(teacher_logits_chunk, dim=-1)
+                # Convert to probs for optional top-k 
slicing
+                teacher_chunk_probs = torch.exp(teacher_log_probs_chunk)
+                student_chunk_probs = torch.exp(student_log_probs_chunk)
+                # Optional top-k over chunks (argument-driven). Applies to chunk mode as well.
+                if kd_topk:
+                    k = min(int(kd_topk), teacher_chunk_probs.shape[-1])
+                    if k > 0 and k < teacher_chunk_probs.shape[-1]:
+                        topk = torch.topk(teacher_chunk_probs, k=k, dim=-1)
+                        idx = topk.indices
+                        teacher_probs_k = topk.values
+                        student_probs_k = torch.gather(student_chunk_probs, dim=-1, index=idx)
+                        # Renormalize over the k retained entries
+                        t_denom = teacher_probs_k.sum(dim=-1, keepdim=True).clamp_min(1e-10)
+                        s_denom = student_probs_k.sum(dim=-1, keepdim=True).clamp_min(1e-10)
+                        teacher_probs_k = teacher_probs_k / t_denom
+                        student_probs_k = student_probs_k / s_denom
+                        # Replace log-probs and probs with the k-sliced versions
+                        teacher_chunk_probs = teacher_probs_k
+                        student_chunk_probs = student_probs_k
+                        teacher_log_probs_chunk = torch.log(teacher_probs_k + 1e-10)
+                        student_log_probs_chunk = torch.log(student_probs_k + 1e-10)
+                # Compute KL on the (possibly) reduced distributions
+                loss_kl = torch.nn.functional.kl_div(student_log_probs_chunk, teacher_log_probs_chunk, reduction="none", log_target=True)
+
+                # Average over valid chunks only
+                loss_kl_weighted = (loss_kl * chunk_mask[:, :, None]).sum() / chunk_mask.sum()
+                loss_kl = loss_kl_weighted
+
+                # Compute top-1 accuracy over valid chunks (projected logits).
+                # Iterate over the chunk dimension, not the sequence dimension: the
+                # chunk tensors only have max_n_chunks entries, and invalid chunks
+                # must not be counted.
+                with torch.no_grad():
+                    student_top1 = student_chunk_probs.argmax(dim=-1)  # (batch, max_n_chunks)
+                    teacher_top1 = teacher_chunk_probs.argmax(dim=-1)  # (batch, max_n_chunks)
+                    matches = ((student_top1 == teacher_top1) & chunk_mask).sum().item()
+                    total = chunk_mask.sum().item()
+                    top1_accuracy = matches / total if total > 0 else 0.0
+
+        # Scale loss by temperature squared (standard distillation convention)
+        return loss_kl * (temperature**2), top1_accuracy
+
+    def compute_ce_loss(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=False):
+        """Cross-entropy-style loss on the main-path (gold token) likelihoods.
+
+        For each aligned chunk, the per-token log-likelihoods of the actual input
+        tokens are summed on the student and teacher side, and a binary
+        cross-entropy-style distance is computed between the two chunk-level
+        log-likelihoods. No vocabulary projection is needed, since only the
+        probabilities of the gold tokens are compared.
+        """
+        max_length_teacher = teacher_logits.shape[1]
+        max_length_student = student_logits.shape[1]
+        max_n_chunks = min(max_length_teacher, max_length_student)
+        n_examples = student_logits.shape[0]  # batch size
+
+        teacher_tokens_to_chunks = torch.zeros((n_examples, max_length_teacher, max_n_chunks), dtype=torch.bool).to(student_logits.device)
+        student_tokens_to_chunks = torch.zeros((n_examples, max_length_student, max_n_chunks), dtype=torch.bool).to(student_logits.device)
+
+        # Use the alignment mask to build one chunk per valid alignment pair
+        for example_idx in range(n_examples):
+            chunk_idx = 0
+            for alignment_pair in tokenids_with_exact_match[example_idx]:
+                s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6]
+                if start1 != -1 and start2 != -1:
+                    teacher_tokens_to_chunks[example_idx, start2:end2, chunk_idx] = 1
+                    student_tokens_to_chunks[example_idx, start1:end1, chunk_idx] = 1
+                    chunk_idx += 1
+
+        teacher_logprobs = torch.log_softmax(teacher_logits, -1)
+        student_logprobs = torch.log_softmax(student_logits, -1)
+
+        # Shift by one: logits at position t-1 score the token observed at position t
+        teacher_main_path_logprobs = torch.take_along_dim(teacher_logprobs[:, :-1], input_ids_teacher[:, 1:, None], dim=-1).squeeze(-1)
+        student_main_path_logprobs = torch.take_along_dim(student_logprobs[:, :-1], input_ids_student[:, 1:, None], dim=-1).squeeze(-1)
+
+        def log1mexp(x):
+            """Computes log(1 - exp(x)) in a numerically stable way for x < 0."""
+            # For x < log(0.5), use log1p(-exp(x)) directly
+            # For x >= 
log(0.5), use log(-expm1(x)) to avoid precision issues + log_half = -torch.log(torch.tensor(2, device=x.device)) + return torch.where(x < log_half, torch.log1p(-torch.exp(x)), torch.log(-torch.expm1(x))) + + def distance_fn(log_y_true, log_y_pred, temp=100, epsilon=1e-6): + log_y_true = (log_y_true.to(torch.float32) / temp) - epsilon + log_y_pred = (log_y_pred.to(torch.float32) / temp) - epsilon + + return -( + torch.exp(log_y_true) * log_y_pred + + (-torch.expm1(log_y_true) * log1mexp(log_y_pred)) + ) + teacher_chunk_logprobs = torch.matmul( + teacher_main_path_logprobs[:, None, :], + teacher_tokens_to_chunks[:, 1:].to(teacher_main_path_logprobs.dtype), + ) + student_chunk_logprobs = torch.matmul( + student_main_path_logprobs[:, None, :], + student_tokens_to_chunks[:, 1:].to(student_main_path_logprobs.dtype), + ) + # or equivalently, student_tokens_to_chunks[:, 1:].sum(-2) > 0 + chunk_mask = (teacher_tokens_to_chunks[:, 1:].sum(-2) > 0) & (student_tokens_to_chunks[:, 1:].sum(-2) > 0) + #is it the place to put only exact matches? + elementwise_loss = distance_fn(teacher_chunk_logprobs, student_chunk_logprobs) + + loss = (elementwise_loss * chunk_mask).mean() / chunk_mask.to(torch.float32).mean() + return loss + + def compute_KL_loss_with_checkpointing(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=False, temperature=0.1, loss_on_non_zero_only=False, debug_verbose=False, kd_topk: int = 0): + """ + Memory-efficient KL loss using gradient checkpointing only. + + This is a drop-in replacement for compute_KL_loss that uses gradient checkpointing + to reduce memory usage during the backward pass. The forward computation is identical. 
+ + Args: + Same as compute_KL_loss + + Returns: + Same as compute_KL_loss: (loss_tensor, accuracy_float) + """ + # If exact-token mode, fall back to original compute with checkpointing + if exact_token_match_only: + return torch.utils.checkpoint.checkpoint( + self.compute_KL_loss, + aligned_pairs, student_logits, teacher_logits, + input_ids_student, input_ids_teacher, tokenids_with_exact_match, + exact_token_match_only, temperature, loss_on_non_zero_only, debug_verbose, kd_topk, + use_reentrant=False + ) + + # Sequence microbatching for chunk-based KL to reduce peak memory + device = student_logits.device + batch_size = student_logits.shape[0] + student_seq_len = student_logits.shape[1] + teacher_seq_len = teacher_logits.shape[1] + teacher_vocab_size = teacher_logits.shape[-1] + + # Build alignment masks (same as in compute_KL_loss chunk path) + max_n_chunks = min(student_seq_len, teacher_seq_len) + projected_tokens_to_chunks = torch.zeros((batch_size, student_seq_len, max_n_chunks), dtype=torch.bool, device=device) + target_tokens_to_chunks = torch.zeros((batch_size, teacher_seq_len, max_n_chunks), dtype=torch.bool, device=device) + + for example_idx in range(batch_size): + chunk_idx = 0 + for alignment_pair in aligned_pairs[example_idx]: + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if start1 != -1 and start2 != -1 and chunk_idx < max_n_chunks: + projected_tokens_to_chunks[example_idx, start1:end1, chunk_idx] = 1 + target_tokens_to_chunks[example_idx, start2:end2, chunk_idx] = 1 + chunk_idx += 1 + + # Accumulators for chunk sums + projected_chunk_sums = torch.zeros((batch_size, max_n_chunks, teacher_vocab_size), dtype=student_logits.dtype, device=device) + target_chunk_sums = torch.zeros((batch_size, max_n_chunks, teacher_vocab_size), dtype=teacher_logits.dtype, device=device) + + # Windowed student projection and accumulation + window = 128 + # Determine projection mode + use_sparse = hasattr(self, 'sparse_transformation_matrix') and 
(self.sparse_transformation_matrix is not None) + has_dense = hasattr(self, 'likelihood_projection_indices') and (self.likelihood_projection_indices is not None) + + for s in range(0, student_seq_len, window): + e = min(s + window, student_seq_len) + # Student slice probs + student_probs_slice = torch.softmax(student_logits[:, s:e, :] / temperature, dim=-1) + # Project slice to teacher vocab + if use_sparse: + projected_slice = self.project_token_likelihoods_instance( + student_probs_slice, + None, None, None, + device, + use_sparse_format=True, + sparse_matrix=self.sparse_transformation_matrix + ) + elif has_dense: + projected_slice = self.project_token_likelihoods_instance( + student_probs_slice, + self.likelihood_projection_indices, + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) if getattr(self, 'learnable', False) else self.likelihood_projection_matrix, + teacher_vocab_size, + device, + use_sparse_format=False + ) + else: + raise ValueError("No projection matrix loaded. 
Please call _load_logits_projection_map() first.") + + mask_slice = projected_tokens_to_chunks[:, s:e, :] # (B, window_len, max_n_chunks) + # (B, max_n_chunks, window_len) @ (B, window_len, Vt) -> (B, max_n_chunks, Vt) + partial = torch.bmm(mask_slice.transpose(1, 2).to(projected_slice.dtype), projected_slice) + projected_chunk_sums += partial + + # Windowed teacher accumulation + for s in range(0, teacher_seq_len, window): + e = min(s + window, teacher_seq_len) + teacher_probs_slice = torch.softmax(teacher_logits[:, s:e, :] / temperature, dim=-1) + mask_slice = target_tokens_to_chunks[:, s:e, :] + partial = torch.bmm(mask_slice.transpose(1, 2).to(teacher_probs_slice.dtype), teacher_probs_slice) + target_chunk_sums += partial + + # Normalize by chunk sizes (mean over tokens inside chunk) + chunk_sizes_projected = projected_tokens_to_chunks.sum(dim=1, keepdim=True).float().transpose(1, 2) # (B, max_n_chunks, 1) + chunk_sizes_target = target_tokens_to_chunks.sum(dim=1, keepdim=True).float().transpose(1, 2) # (B, max_n_chunks, 1) + + projected_chunk_probs = projected_chunk_sums / (chunk_sizes_projected + 1e-10) + target_chunk_probs = target_chunk_sums / (chunk_sizes_target + 1e-10) + + # Renormalize to ensure probabilities sum to 1 + projected_chunk_probs = projected_chunk_probs / (projected_chunk_probs.sum(dim=-1, keepdim=True) + 1e-10) + target_chunk_probs = target_chunk_probs / (target_chunk_probs.sum(dim=-1, keepdim=True) + 1e-10) + + # Optional top-k slicing over chunks + if kd_topk and not loss_on_non_zero_only: + k = min(int(kd_topk), target_chunk_probs.shape[-1]) + if k > 0 and k < target_chunk_probs.shape[-1]: + topk = torch.topk(target_chunk_probs, k=k, dim=-1) + indices_k = topk.indices + target_probs_k = topk.values + projected_probs_k = torch.gather(projected_chunk_probs, dim=-1, index=indices_k) + # Renormalize over k + t_denom = target_probs_k.sum(dim=-1, keepdim=True).clamp_min(1e-10) + s_denom = projected_probs_k.sum(dim=-1, 
keepdim=True).clamp_min(1e-10) + target_chunk_probs = target_probs_k / t_denom + projected_chunk_probs = projected_probs_k / s_denom + + # Convert to log-probs + projected_chunk_log_probs = torch.log(projected_chunk_probs + 1e-10) + target_chunk_log_probs = torch.log(target_chunk_probs + 1e-10) + + # Valid chunk mask + chunk_mask = (chunk_sizes_projected.squeeze(-1) > 0) & (chunk_sizes_target.squeeze(-1) > 0) + + # KL divergence per chunk + loss_kl = torch.nn.functional.kl_div( + projected_chunk_log_probs, + target_chunk_log_probs, + reduction="none", + log_target=True, + ) + if chunk_mask.sum() > 0: + loss_kl = (loss_kl * chunk_mask.unsqueeze(-1)).sum() / chunk_mask.sum() + else: + loss_kl = torch.tensor(0.0, device=device, requires_grad=True) + + # Top-1 accuracy over chunks + with torch.no_grad(): + if chunk_mask.sum() > 0: + student_top1 = torch.argmax(projected_chunk_probs, dim=-1) + teacher_top1 = torch.argmax(target_chunk_probs, dim=-1) + matches = ((student_top1 == teacher_top1) & chunk_mask).sum().item() + total = chunk_mask.sum().item() + top1_accuracy = matches / total if total > 0 else 0.0 + else: + top1_accuracy = 0.0 + + return loss_kl * (temperature ** 2), top1_accuracy + + def compute_KL_loss_optimized(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=False, temperature=1.0, loss_on_non_zero_only=False, debug_verbose=False, + kd_topk: int = 0, vocab_topk: int = 8192, reverse_kl: bool = False, project_teacher_logits_to_student: bool = False, log_softmax: str = "together", token_weights=None, gold_loss: bool = False, xtoken_loss: bool = False): + """ + Heavily optimized KL loss computation for large vocabularies. 
+
+        Key optimizations:
+        - Pre-filter vocabulary to top-K teacher tokens globally
+        - Fused softmax + log operations
+        - Reduced intermediate tensor allocations
+        - Early exit for empty alignments
+
+        Args:
+            vocab_topk: Reduce effective vocabulary size to this many tokens based on teacher logits
+            project_teacher_logits_to_student: If True, project teacher logits to student space (instead of student to teacher)
+            gold_loss: If True, use gold loss computation (no vocab transformation for chunks, direct logit averaging)
+            Other args same as compute_KL_loss
+        """
+        if not aligned_pairs or not any(aligned_pairs):
+            return torch.tensor(0.0, device=student_logits.device, requires_grad=True), 0.0
+
+        device = student_logits.device
+        batch_size, student_seq_len, student_vocab_size = student_logits.shape
+        teacher_seq_len, teacher_vocab_size = teacher_logits.shape[1], teacher_logits.shape[2]
+
+        # Gold loss path: split into exact-mapped (common) and non-exact (uncommon) vocab
+        if gold_loss:
+            # Step 1: Create exact token map from projection matrix
+            # Only include student tokens that have exactly one strong mapping to a teacher token
+            if not hasattr(self, 'likelihood_projection_indices') or self.likelihood_projection_indices is None:
+                raise ValueError("gold_loss requires likelihood_projection_indices to be loaded")
+
+            projection_indices = self.likelihood_projection_indices  # (student_vocab, top_k)
+            projection_matrix = self.transform_learned_matrix_instance(self.likelihood_projection_matrix) if getattr(self, 'learnable', False) else self.likelihood_projection_matrix
+
+            # Sort projection weights for each student token to find its strongest mappings
+            sorted_values, sorted_indices_in_topk = torch.sort(projection_matrix, dim=-1, 
descending=True)
+
+            # A student token counts as an exact mapping when its strongest projection
+            # weight dominates; the threshold depends on the mode selected below
+            if xtoken_loss:
+                # Cross-token mode: treat a top-1 projection weight >= 0.6 as an exact
+                # mapping. This drops multi-token projections and avoids collisions,
+                # which otherwise make the KL loss shoot up.
+                has_exact_map = (sorted_values[:, 0] >= 0.6)
+            else:
+                # GOLD mode: only perfect one-to-one mappings (weight exactly 1.0 and
+                # no second candidate); everything else goes to the ULD (uncommon) path
+                has_exact_map = (sorted_values[:, 0] == 1.0) & (projection_indices[:, 1] == -1)
+
+            # Get the actual teacher token indices for exact mappings
+            # projection_indices[student_idx, k] gives the teacher token for the k-th strongest mapping
+            student_indices_with_exact_map = torch.where(has_exact_map)[0]
+            teacher_indices_for_exact_map = projection_indices[student_indices_with_exact_map, sorted_indices_in_topk[student_indices_with_exact_map, 0]]
+
+            # Create mapping dictionaries for quick lookup
+            student_to_teacher_exact_map = {}
+            teacher_to_student_exact_map = {}
+            teacher_collision_count = 0
+            teacher_collisions = []  # (teacher_idx, existing_student_idx, new_student_idx) triples
+
+            for s_idx, t_idx in zip(student_indices_with_exact_map.tolist(), teacher_indices_for_exact_map.tolist()):
+                # Only keep if teacher index is valid
+                if 0 <= t_idx < teacher_vocab_size:
+                    if t_idx not in teacher_to_student_exact_map or xtoken_loss:
+                        if t_idx in teacher_to_student_exact_map:
+                            # In xtoken mode, resolve collisions by keeping the mapping
+                            # with the higher projection weight
+                            prev_student_token = teacher_to_student_exact_map[t_idx]
+                            prev_prob = sorted_values[prev_student_token, 0]
+                            if prev_prob >= sorted_values[s_idx, 0]:
+                                continue
+                            del student_to_teacher_exact_map[prev_student_token]
+
+                        student_to_teacher_exact_map[s_idx] = t_idx
+                        teacher_to_student_exact_map[t_idx] = s_idx
+                    else:
+                        # Collision: teacher token already mapped to a different student token
+                        teacher_collision_count += 1
+                        existing_s_idx = teacher_to_student_exact_map[t_idx]
+                        teacher_collisions.append((t_idx, existing_s_idx, s_idx))
+
+            # Step 2: Split indices into common (exact match) and uncommon (no exact match)
+            common_student_indices = sorted(student_to_teacher_exact_map.keys())
+            common_teacher_indices = [student_to_teacher_exact_map[s] for s in common_student_indices]
+
+            all_student_indices = set(range(student_vocab_size))
+            all_teacher_indices = set(range(teacher_vocab_size))
+            uncommon_student_indices = sorted(all_student_indices - set(common_student_indices))
+            uncommon_teacher_indices = sorted(all_teacher_indices - set(common_teacher_indices))
+
+            # Step 3: Compute loss using chunk-based masking
+            # Build chunk masks for all alignments (not just exact 1-to-1)
+            max_n_chunks = min(student_seq_len, teacher_seq_len)
+
+            student_chunk_mask = torch.zeros((batch_size, student_seq_len, max_n_chunks),
+                                             dtype=torch.bool, device=device)
+            teacher_chunk_mask = torch.zeros((batch_size, teacher_seq_len, max_n_chunks),
+                                             dtype=torch.bool, device=device)
+
+            # Fill chunk masks from alignment pairs
+            for batch_idx in range(batch_size):
+                for chunk_idx, alignment_pair in enumerate(aligned_pairs[batch_idx][:max_n_chunks]):
+                    s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6]
+                    if start1 != -1 and start2 != -1:
+                        student_chunk_mask[batch_idx, start1:end1, chunk_idx] = True
+                        teacher_chunk_mask[batch_idx, start2:end2, chunk_idx] = True
+
+            # Compute log_softmax on the original logits BEFORE averaging
+            student_log_probs = torch.log_softmax(student_logits / temperature, dim=-1)
+            teacher_log_probs = torch.log_softmax(teacher_logits / temperature, dim=-1)
+
+            # Average log 
probabilities within chunks for FULL vocabularies + student_chunk_log_probs_full = torch.bmm( + student_chunk_mask.transpose(1, 2).to(student_log_probs.dtype), + student_log_probs + ) # (batch, max_n_chunks, student_vocab_size) + + teacher_chunk_log_probs_full = torch.bmm( + teacher_chunk_mask.transpose(1, 2).to(teacher_log_probs.dtype), + teacher_log_probs + ) # (batch, max_n_chunks, teacher_vocab_size) + + # Normalize by chunk sizes + student_chunk_sizes = student_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) # (batch, max_n_chunks, 1) + teacher_chunk_sizes = teacher_chunk_mask.sum(dim=1, keepdim=True).float().transpose(1, 2) # (batch, max_n_chunks, 1) + + student_chunk_log_probs_full = student_chunk_log_probs_full / (student_chunk_sizes + 1e-10) + teacher_chunk_log_probs_full = teacher_chunk_log_probs_full / (teacher_chunk_sizes + 1e-10) + + # Valid chunk mask + chunk_mask = (student_chunk_sizes.squeeze(-1) > 0) & (teacher_chunk_sizes.squeeze(-1) > 0) + + if not chunk_mask.any(): + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + # Now split chunk-averaged log probs into common and uncommon vocab + # Extract common and uncommon from chunk-averaged log probs + if len(common_student_indices) > 0: + common_student_indices_tensor = torch.tensor(common_student_indices, device=device) + common_teacher_indices_tensor = torch.tensor(common_teacher_indices, device=device) + + student_chunk_common_log_probs = student_chunk_log_probs_full[:, :, common_student_indices_tensor] # (B, chunks, num_common) + teacher_chunk_common_log_probs = teacher_chunk_log_probs_full[:, :, common_teacher_indices_tensor] # (B, chunks, num_common) + else: + student_chunk_common_log_probs = torch.empty(batch_size, max_n_chunks, 0, device=device) + teacher_chunk_common_log_probs = torch.empty(batch_size, max_n_chunks, 0, device=device) + + if len(uncommon_student_indices) > 0: + uncommon_student_indices_tensor = torch.tensor(uncommon_student_indices, 
device=device)
+                student_chunk_uncommon_log_probs = student_chunk_log_probs_full[:, :, uncommon_student_indices_tensor]  # (B, chunks, num_uncommon_s)
+            else:
+                student_chunk_uncommon_log_probs = torch.empty(batch_size, max_n_chunks, 0, device=device)
+
+            if len(uncommon_teacher_indices) > 0:
+                uncommon_teacher_indices_tensor = torch.tensor(uncommon_teacher_indices, device=device)
+                teacher_chunk_uncommon_log_probs = teacher_chunk_log_probs_full[:, :, uncommon_teacher_indices_tensor]  # (B, chunks, num_uncommon_t)
+            else:
+                teacher_chunk_uncommon_log_probs = torch.empty(batch_size, max_n_chunks, 0, device=device)
+
+            # Part 1: KL loss on common (aligned) vocab - using pre-computed log probs
+            loss_kl_common = torch.tensor(0.0, device=device, requires_grad=True)
+            if student_chunk_common_log_probs.shape[-1] > 0:
+                # Compute KL divergence per chunk using pre-computed log probs
+                if not reverse_kl:
+                    loss_kl_per_elem = torch.nn.functional.kl_div(
+                        student_chunk_common_log_probs, teacher_chunk_common_log_probs,
+                        reduction="none", log_target=True
+                    )
+                else:
+                    loss_kl_per_elem = torch.nn.functional.kl_div(
+                        teacher_chunk_common_log_probs, student_chunk_common_log_probs,
+                        reduction="none", log_target=True
+                    )
+
+                # Sum across the vocab dimension
+                loss_kl_per_chunk = loss_kl_per_elem.sum(dim=-1)  # (batch, max_n_chunks)
+
+                # Mask invalid chunks
+                loss_kl_per_chunk = loss_kl_per_chunk * chunk_mask
+
+                if chunk_mask.sum() > 0:
+                    if token_weights is not None:
+                        # Map chunk losses to teacher token positions
+                        loss_kl_per_teacher_token = torch.bmm(
+                            teacher_chunk_mask.to(loss_kl_per_chunk.dtype),
+                            loss_kl_per_chunk.unsqueeze(-1)
+                        ).squeeze(-1)
+
+                        weighted_loss_per_token = loss_kl_per_teacher_token * token_weights
+
+                        valid_teacher_sizes = teacher_chunk_sizes.squeeze(-1) * chunk_mask
+                        total_teacher_token_participations = valid_teacher_sizes.sum()
+                        if total_teacher_token_participations > 0:
+                            loss_kl_common = weighted_loss_per_token.sum() / total_teacher_token_participations
+                    else:
+                        loss_kl_common = loss_kl_per_chunk.sum() / chunk_mask.sum()
+
+            # Part 2: L1 loss on uncommon (unaligned) vocab - sort chunk-averaged probabilities
+            loss_l1_uncommon = torch.tensor(0.0, device=device, requires_grad=True)
+            if student_chunk_uncommon_log_probs.shape[-1] > 0 or teacher_chunk_uncommon_log_probs.shape[-1] > 0:
+                # Get valid chunks only
+                student_uncommon_valid = student_chunk_uncommon_log_probs[chunk_mask]  # (num_valid_chunks, num_uncommon_s)
+                teacher_uncommon_valid = teacher_chunk_uncommon_log_probs[chunk_mask]  # (num_valid_chunks, num_uncommon_t)
+
+                if student_uncommon_valid.shape[0] > 0:
+                    # Convert log probabilities to probabilities using exp - use in-place operations where possible
+                    with torch.no_grad():
+                        # Use topk instead of a full sort to reduce memory - only the sorted values are needed
+                        # Limit the vocab size for uncommon distributions to prevent OOM
+                        max_uncommon_vocab = min(
+                            student_uncommon_valid.shape[-1],
+                            teacher_uncommon_valid.shape[-1],
+                            8192  # Cap at a reasonable size to prevent OOM
+                        )
+
+                        if max_uncommon_vocab > 0:
+                            student_uncommon_probs = torch.exp(student_uncommon_valid)
+                            teacher_uncommon_probs = torch.exp(teacher_uncommon_valid)
+
+                            # topk is much more memory efficient than a full sort
+                            if student_uncommon_probs.shape[-1] > max_uncommon_vocab:
+                                student_uncommon_sorted, _ = torch.topk(student_uncommon_probs, k=max_uncommon_vocab, dim=-1, largest=True)
+                            else:
+                                student_uncommon_sorted = torch.sort(student_uncommon_probs, dim=-1, descending=True)[0]
+
+                            if teacher_uncommon_probs.shape[-1] > max_uncommon_vocab:
+                                teacher_uncommon_sorted, _ = torch.topk(teacher_uncommon_probs, k=max_uncommon_vocab, 
dim=-1, largest=True)
+                            else:
+                                teacher_uncommon_sorted = torch.sort(teacher_uncommon_probs, dim=-1, descending=True)[0]
+
+                            # Free intermediate tensors immediately
+                            del student_uncommon_probs, teacher_uncommon_probs
+
+                            # Compare the two sorted tails over a common length
+                            min_uncommon_len = min(student_uncommon_sorted.shape[-1], teacher_uncommon_sorted.shape[-1])
+                            if min_uncommon_len > 0:
+                                student_uncommon_sorted = student_uncommon_sorted[:, :min_uncommon_len]
+                                teacher_uncommon_sorted = teacher_uncommon_sorted[:, :min_uncommon_len]
+
+                                # Compute L1 loss on the sorted uncommon probabilities
+                                loss_l1_per_chunk = torch.nn.functional.l1_loss(
+                                    student_uncommon_sorted, teacher_uncommon_sorted, reduction='none'
+                                ).sum(dim=-1)  # Sum over the vocab dimension
+
+                                # Free the sorted tensors immediately after computing the loss
+                                del student_uncommon_sorted, teacher_uncommon_sorted
+
+                                # Apply token weights if provided
+                                if token_weights is not None:
+                                    # Valid (batch_idx, chunk_idx) pairs
+                                    chunk_indices = torch.nonzero(chunk_mask, as_tuple=False)  # (num_valid_chunks, 2)
+
+                                    # Map chunks back to teacher tokens for weighting
+                                    weighted_l1_per_chunk = torch.zeros_like(loss_l1_per_chunk)
+                                    for valid_idx, (batch_idx, chunk_idx) in enumerate(chunk_indices):
+                                        # Teacher tokens participating in this chunk
+                                        teacher_tokens_in_chunk = teacher_chunk_mask[batch_idx, :, chunk_idx]
+                                        if teacher_tokens_in_chunk.any():
+                                            # Average token weights for tokens in this chunk
+                                            chunk_weight = token_weights[batch_idx, teacher_tokens_in_chunk].mean()
+                                            weighted_l1_per_chunk[valid_idx] = loss_l1_per_chunk[valid_idx] * chunk_weight
+
+                                    loss_l1_uncommon = weighted_l1_per_chunk.mean()
+                                    del weighted_l1_per_chunk, chunk_indices
+                                else:
+                                    loss_l1_uncommon = loss_l1_per_chunk.mean()
+
+                                del loss_l1_per_chunk
+
+            # Combine losses
+            loss_total = loss_kl_common + loss_l1_uncommon
+            # 
+ # Free large tensors before accuracy computation to ensure memory is available + # These are no longer needed for the loss computation + del student_chunk_log_probs_full, teacher_chunk_log_probs_full + del student_chunk_mask, teacher_chunk_mask + if len(uncommon_student_indices) > 0: + del student_chunk_uncommon_log_probs + if len(uncommon_teacher_indices) > 0: + del teacher_chunk_uncommon_log_probs + + # Accuracy computation on common vocab - using pre-computed log probs + # MEMORY OPTIMIZED: argmax works directly on log probs without needing exp() + with torch.no_grad(): + if student_chunk_common_log_probs.shape[-1] > 0 and chunk_mask.any(): + # Get predictions for valid chunks BEFORE exp to save memory + # argmax is invariant to strictly increasing transformations, so argmax(log_probs) == argmax(probs) + student_chunk_log_probs_valid = student_chunk_common_log_probs[chunk_mask] + teacher_chunk_log_probs_valid = teacher_chunk_common_log_probs[chunk_mask] + + # Compute argmax directly on log probabilities (saves significant memory) + student_top1 = student_chunk_log_probs_valid.argmax(dim=-1) + teacher_top1 = teacher_chunk_log_probs_valid.argmax(dim=-1) + matches = (student_top1 == teacher_top1).sum().item() + top1_accuracy = matches / chunk_mask.sum().item() + + # Clean up accuracy computation tensors + del student_chunk_log_probs_valid, teacher_chunk_log_probs_valid + del student_top1, teacher_top1 + else: + top1_accuracy = 0.0 + + # Clean up remaining tensors + del chunk_mask + if len(common_student_indices) > 0: + del student_chunk_common_log_probs, teacher_chunk_common_log_probs + + return loss_total * (temperature ** 2), top1_accuracy + + if project_teacher_logits_to_student: + # REVERSE PROJECTION: Teacher → Student + + # Step 1: Project teacher_logits (via probs) to student space + if log_softmax == "separate": +
teacher_probs = torch.softmax(teacher_logits / temperature, dim=-1) + else: + teacher_probs = teacher_logits # raw logits pass through here; they are normalized later in the "together" branch + + if hasattr(self, 'reverse_sparse_transformation_matrix') and self.reverse_sparse_transformation_matrix is not None: + projected_teacher_probs_full = self.project_token_likelihoods_instance( + teacher_probs, None, None, None, device, + use_sparse_format=True, sparse_matrix=self.reverse_sparse_transformation_matrix + ) + # projected_teacher_probs_full is now in student space, full vocab (B, T, student_vocab_size) + + elif hasattr(self, 'reverse_likelihood_projection_indices') and self.reverse_likelihood_projection_indices is not None: + reverse_matrix = self.reverse_likelihood_projection_matrix + if getattr(self, 'learnable', False): + reverse_matrix = self.transform_learned_matrix_instance(reverse_matrix) + + projected_teacher_probs_full = self.project_token_likelihoods_instance( + teacher_probs, + self.reverse_likelihood_projection_indices, + reverse_matrix, + student_vocab_size, device, use_sparse_format=False + ) + # projected_teacher_probs_full is now in student space, full vocab (B, T, student_vocab_size) + else: + raise ValueError("Reverse projection matrices not found.
Please call create_reverse_projection_matrix() first.") + + # Step 2: Compute global_top_indices based on projected teacher probs (in student vocab space) + # Use max probability per vocab position to find important student vocab tokens + with torch.no_grad(): + if vocab_topk == 0 or vocab_topk >= student_vocab_size: + # Use all vocabulary tokens (no reduction) + global_top_indices = torch.arange(student_vocab_size, device=device) + else: + # Get globally most important STUDENT tokens based on projected teacher probs + projected_teacher_flat = projected_teacher_probs_full.view(-1, student_vocab_size) + global_teacher_importance = projected_teacher_flat.max(dim=0)[0] # Max prob per vocab token + _, global_top_indices = torch.topk(global_teacher_importance, + k=min(vocab_topk, student_vocab_size), + dim=-1) + global_top_indices = global_top_indices.sort()[0] # Keep sorted for efficiency + + # Step 3: Apply softmax on student_logits + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + # Step 4: Slice both distributions with global_top_indices + if log_softmax == "together": + projected_teacher_probs_reduced = torch.log_softmax(projected_teacher_probs_full / temperature, dim=-1)[:, :, global_top_indices] # (B, T, vocab_topk) + else: + projected_teacher_probs_reduced = projected_teacher_probs_full[:, :, global_top_indices] # (B, T, vocab_topk) + + student_probs_reduced = student_probs[:, :, global_top_indices] # (B, S, vocab_topk) + + # Step 5: Apply log to get target_log_probs (projected probs are already in probability space) + # For consistency with the rest of the code, set projected_probs to student and target to teacher + if log_softmax == "separate": + target_log_probs = torch.log(projected_teacher_probs_reduced + 1e-10) + else: + target_log_probs = projected_teacher_probs_reduced # already log-normalized in the "together" branch above + projected_probs = student_probs_reduced # Student probs (sliced) + + else: + #
FORWARD PROJECTION: Student → Teacher (original behavior) + # Step 1: Global vocabulary filtering (major speedup) + with torch.no_grad(): + if vocab_topk == 0 or vocab_topk >= teacher_vocab_size: + # Use all vocabulary tokens (no reduction) + global_top_indices = torch.arange(teacher_vocab_size, device=device) + else: + # Get globally most important teacher tokens across all positions + teacher_flat = teacher_logits.view(-1, teacher_vocab_size) + global_teacher_importance = teacher_flat.max(dim=0)[0] # Max logit per vocab token + _, global_top_indices = torch.topk(global_teacher_importance, + k=min(vocab_topk, teacher_vocab_size), + dim=-1) + global_top_indices = global_top_indices.sort()[0] # Keep sorted for efficiency + + # Step 2: Project student to reduced teacher vocabulary + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # print(f"Using sparse matrix projection for student_probs") + projected_probs_full = self.project_token_likelihoods_instance( + student_probs, None, None, None, device, + use_sparse_format=True, sparse_matrix=self.sparse_transformation_matrix + ) + projected_probs = projected_probs_full[:, :, global_top_indices] # (B, S, vocab_topk) + + else: + projected_probs_full = self.project_token_likelihoods_instance( + student_probs, + self.likelihood_projection_indices, + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) if getattr(self, 'learnable', False) else self.likelihood_projection_matrix, + teacher_vocab_size, device, use_sparse_format=False, + ) + projected_probs = projected_probs_full[:, :, global_top_indices] # (B, S, vocab_topk) + + # Step 3: Slice to reduced vocabulary + teacher_logits_reduced = teacher_logits[:, :, global_top_indices] # (B, T, vocab_topk) + + # Step 4: Efficient target log-probabilities (fused softmax+log) + # print(f"teacher top 50 max probs after topk: 
{torch.sort(torch.softmax(teacher_logits_reduced, dim=-1), descending=True)[0][:, :50]}") + target_log_probs = torch.log_softmax(teacher_logits_reduced / temperature, dim=-1) + + if exact_token_match_only: + # Optimized exact matching with reduced vocab + student_mask = torch.zeros(batch_size, student_seq_len, dtype=torch.bool, device=device) + teacher_mask = torch.zeros(batch_size, teacher_seq_len, dtype=torch.bool, device=device) + + for batch_idx in range(batch_size): + for alignment_pair in aligned_pairs[batch_idx]: + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if (start1 > 0 and start2 > 0 and + end1 - start1 == 1 and end2 - start2 == 1 and + start1-1 < student_seq_len and start2-1 < teacher_seq_len): + student_mask[batch_idx, start1-1] = True + teacher_mask[batch_idx, start2-1] = True + + if not student_mask.any(): + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + projected_probs_masked = projected_probs[student_mask] + target_log_probs_masked = target_log_probs[teacher_mask] + + # Fused log + KL computation + projected_log_probs_masked = torch.log(projected_probs_masked + 1e-10) + if not reverse_kl: + loss_kl_per_token = torch.nn.functional.kl_div( + projected_log_probs_masked, target_log_probs_masked, + reduction="none", log_target=True + ) + else: + loss_kl_per_token = torch.nn.functional.kl_div( + target_log_probs_masked, projected_log_probs_masked, + reduction="none", log_target=True + ) + + # Sum across vocab dimension: (num_matched_tokens, vocab_topk) -> (num_matched_tokens,) + loss_kl_per_token = loss_kl_per_token.sum(dim=-1) + + # Apply token weights if provided + if token_weights is not None: + #
token_weights are based on teacher tokens, so use teacher_mask + # token_weights shape: (batch_size, teacher_seq_len) + token_weights_masked = token_weights[teacher_mask] # (num_matched_tokens,) + weighted_loss_per_token = loss_kl_per_token * token_weights_masked + # Normalize by number of tokens to ensure comparable loss magnitudes + # while weights still control relative contribution per token + num_matched_tokens = teacher_mask.sum() + if num_matched_tokens > 0: + loss_kl = weighted_loss_per_token.sum() / num_matched_tokens + else: + loss_kl = torch.tensor(0.0, device=device, requires_grad=True) + else: + # Regular batchmean reduction + loss_kl = loss_kl_per_token.mean() + + # Fast accuracy computation + with torch.no_grad(): + matches = (projected_probs_masked.argmax(dim=-1) == + torch.exp(target_log_probs_masked).argmax(dim=-1)).sum().item() + top1_accuracy = matches / projected_probs_masked.shape[0] + + else: + # Chunk-based with reduced vocabulary - similar to original but on smaller vocab + max_n_chunks = min(student_seq_len, teacher_seq_len) + + # Pre-allocate masks (more memory efficient) + proj_mask = torch.zeros((batch_size, student_seq_len, max_n_chunks), + dtype=torch.bool, device=device) + tgt_mask = torch.zeros((batch_size, teacher_seq_len, max_n_chunks), + dtype=torch.bool, device=device) + + # Fill masks efficiently + for batch_idx in range(batch_size): + for chunk_idx, alignment_pair in enumerate(aligned_pairs[batch_idx][:max_n_chunks]): + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if start1 != -1 and start2 != -1: + proj_mask[batch_idx, start1:end1, chunk_idx] = True + tgt_mask[batch_idx, start2:end2, chunk_idx] = True + + # Efficient chunk averaging using bmm + proj_chunks = torch.bmm(proj_mask.transpose(1,2).to(projected_probs.dtype), projected_probs) + tgt_log_chunks = torch.bmm(tgt_mask.transpose(1,2).to(target_log_probs.dtype), target_log_probs) + + # Normalize by chunk sizes + proj_sizes = proj_mask.sum(dim=1, 
keepdim=True).transpose(1,2) + tgt_sizes = tgt_mask.sum(dim=1, keepdim=True).transpose(1,2) + + proj_chunks = proj_chunks / (proj_sizes + 1e-10) + tgt_log_chunks = tgt_log_chunks / (tgt_sizes + 1e-10) + + # Renormalize and compute loss + proj_chunks = proj_chunks / (proj_chunks.sum(dim=-1, keepdim=True) + 1e-10) + proj_log_chunks = torch.log(proj_chunks + 1e-10) + + chunk_mask = (proj_sizes.squeeze(-1) > 0) & (tgt_sizes.squeeze(-1) > 0) + + if not reverse_kl: + loss_kl_per_elem = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True + ) + else: + loss_kl_per_elem = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True + ) + + # Sum across vocab dimension: (batch_size, max_n_chunks, vocab_topk) -> (batch_size, max_n_chunks) + loss_kl_per_chunk = loss_kl_per_elem.sum(dim=-1) + + # Mask invalid chunks + loss_kl_per_chunk = loss_kl_per_chunk * chunk_mask + + if chunk_mask.sum() > 0: + if token_weights is not None: + # Map chunk losses back to TEACHER token positions using the teacher chunk mask + # token_weights are based on teacher tokens, shape: (batch_size, teacher_seq_len) + # tgt_mask shape: (batch_size, teacher_seq_len, max_n_chunks) + # loss_kl_per_chunk shape: (batch_size, max_n_chunks) + + # For each teacher token, accumulate loss from all chunks it participates in + # (batch_size, teacher_seq_len, max_n_chunks) @ (batch_size, max_n_chunks, 1) -> (batch_size, teacher_seq_len, 1) + loss_kl_per_teacher_token = torch.bmm( + tgt_mask.to(loss_kl_per_chunk.dtype), # (batch_size, teacher_seq_len, max_n_chunks) + loss_kl_per_chunk.unsqueeze(-1) # (batch_size, max_n_chunks, 1) + ).squeeze(-1) # -> (batch_size, teacher_seq_len) + + # Weight the loss per teacher token + weighted_loss_per_token = loss_kl_per_teacher_token * token_weights + + # Sum over teacher tokens and normalize + # Normalize by total teacher token participations in VALID chunks only +
# tgt_sizes counts, per chunk, how many teacher tokens participate; the BMM above maps chunk losses back into teacher-token space + # tgt_sizes shape: (batch_size, max_n_chunks, 1), chunk_mask shape: (batch_size, max_n_chunks) + valid_tgt_sizes = tgt_sizes.squeeze(-1) * chunk_mask # (batch_size, max_n_chunks) + total_teacher_token_participations = valid_tgt_sizes.sum() + if total_teacher_token_participations > 0: + # No need to mask: tokens only in invalid chunks already have loss=0 + loss_kl = weighted_loss_per_token.sum() / total_teacher_token_participations + else: + loss_kl = torch.tensor(0.0, device=device, requires_grad=True) + else: + # Regular reduction by number of valid chunks + loss_kl = loss_kl_per_chunk.sum() / chunk_mask.sum() + else: + loss_kl = torch.tensor(0.0, device=device, requires_grad=True) + # Accuracy computation + with torch.no_grad(): + if chunk_mask.sum() > 0: + proj_top1 = proj_chunks.argmax(dim=-1) + tgt_top1 = torch.exp(tgt_log_chunks).argmax(dim=-1) + matches = ((proj_top1 == tgt_top1) & chunk_mask).sum().item() + top1_accuracy = matches / chunk_mask.sum().item() + else: + top1_accuracy = 0.0 + + return loss_kl * (temperature ** 2), top1_accuracy + + def compute_KL_loss_ultra_fast(self, aligned_pairs, student_logits, teacher_logits, input_ids_student, input_ids_teacher, tokenids_with_exact_match=None, exact_token_match_only=False, temperature=1.0, vocab_topk: int = 4096, use_mixed_precision=True, reverse_kl: bool = False): + """ + Ultra-fast KL loss with maximum optimizations for production use.
+ + Key optimizations: + - Aggressive vocabulary pruning (4k default vs 128k) + - In-place operations where possible + - Pre-allocated tensor reuse + - Fused softmax-log operations + - Mixed precision (fp16/bf16) for intermediate computations + - Minimal tensor copying + """ + if not aligned_pairs or not any(aligned_pairs): + return torch.tensor(0.0, device=student_logits.device, requires_grad=True), 0.0 + + device = student_logits.device + batch_size, student_seq_len = student_logits.shape[:2] + teacher_seq_len, teacher_vocab_size = teacher_logits.shape[1], teacher_logits.shape[2] + + # Cache key for vocab filtering to avoid recomputation + cache_key = (teacher_logits.shape, temperature, vocab_topk) + + if not hasattr(self, '_vocab_cache') or self._vocab_cache.get('key') != cache_key: + with torch.no_grad(): + # More aggressive vocabulary filtering - use mean instead of max for better coverage + teacher_flat = teacher_logits.view(-1, teacher_vocab_size) + # Combine max and mean for better token selection + global_importance = 0.7 * teacher_flat.max(dim=0)[0] + 0.3 * teacher_flat.mean(dim=0) + _, top_indices = torch.topk(global_importance, k=min(vocab_topk, teacher_vocab_size)) + top_indices = top_indices.sort()[0] + + # Cache the indices and create index mapping for fast lookup + self._vocab_cache = { + 'key': cache_key, + 'indices': top_indices, + 'inv_mapping': torch.full((teacher_vocab_size,), -1, dtype=torch.long, device=device) + } + self._vocab_cache['inv_mapping'][top_indices] = torch.arange(len(top_indices), device=device) + + top_indices = self._vocab_cache['indices'] + + # Use mixed precision for intermediate computations if available + compute_dtype = torch.float16 if use_mixed_precision and torch.cuda.is_available() else student_logits.dtype + + # Step 1: Slice vocabularies early to reduce all subsequent operations + teacher_logits_reduced = teacher_logits[:, :, top_indices] # (B, T, vocab_topk) + + # Step 2: Project student efficiently with ultra-fast 
projection + student_probs = torch.softmax(student_logits / temperature, dim=-1) + + if hasattr(self, 'sparse_transformation_matrix') and self.sparse_transformation_matrix is not None: + # Use ultra-fast projection with direct vocabulary slicing + projected_probs = self.project_token_likelihoods_ultra_fast( + student_probs, + sparse_matrix=self.sparse_transformation_matrix, + target_vocab_reduced_indices=top_indices + ).to(compute_dtype) + else: + # Fallback to regular projection + slicing + projected_probs_full = self.project_token_likelihoods_instance( + student_probs, + self.likelihood_projection_indices, + self.transform_learned_matrix_instance(self.likelihood_projection_matrix) if getattr(self, 'learnable', False) else self.likelihood_projection_matrix, + teacher_vocab_size, device, use_sparse_format=False + ) + projected_probs = projected_probs_full[:, :, top_indices].to(compute_dtype) + + # Step 3: Fused log-softmax on reduced teacher logits + teacher_log_probs = torch.log_softmax(teacher_logits_reduced / temperature, dim=-1).to(compute_dtype) + + if exact_token_match_only: + # Ultra-fast exact matching using vectorized operations + valid_positions = [] + + for batch_idx in range(batch_size): + batch_positions = [] + for alignment_pair in aligned_pairs[batch_idx]: + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if (start1 > 0 and start2 > 0 and + end1 - start1 == 1 and end2 - start2 == 1 and + start1-1 < student_seq_len and start2-1 < teacher_seq_len): + batch_positions.append((start1-1, start2-1)) + valid_positions.append(batch_positions) + + # If no valid positions, return zero + total_positions = sum(len(positions) for positions in valid_positions) + if total_positions == 0: + return torch.tensor(0.0, device=device, requires_grad=True), 0.0 + + # Vectorized gathering of valid positions + proj_list = [] + tgt_list = [] + for batch_idx, positions in enumerate(valid_positions): + for s_pos, t_pos in positions: + 
proj_list.append(projected_probs[batch_idx, s_pos]) + tgt_list.append(teacher_log_probs[batch_idx, t_pos]) + + projected_probs_masked = torch.stack(proj_list, dim=0) + target_log_probs_masked = torch.stack(tgt_list, dim=0) + + # Log of projected probs (epsilon added for numerical stability) + projected_log_probs_masked = torch.log(projected_probs_masked + 1e-10) + + # Fused KL computation + if not reverse_kl: + loss_kl = torch.nn.functional.kl_div( + projected_log_probs_masked, target_log_probs_masked, + reduction="batchmean", log_target=True + ).to(student_logits.dtype) + else: + loss_kl = torch.nn.functional.kl_div( + target_log_probs_masked, projected_log_probs_masked, + reduction="batchmean", log_target=True + ).to(student_logits.dtype) + + # Fast accuracy - use argmax directly on reduced vocab + with torch.no_grad(): + proj_argmax = projected_probs_masked.argmax(dim=-1) + tgt_argmax = torch.exp(target_log_probs_masked).argmax(dim=-1) + top1_accuracy = (proj_argmax == tgt_argmax).float().mean().item() + + else: + # Optimized chunk-based processing + max_n_chunks = min(student_seq_len, teacher_seq_len, 512) # Limit chunks for speed + + # Pre-allocate reusable tensors + if not hasattr(self, '_chunk_cache') or self._chunk_cache['batch_size'] != batch_size: + self._chunk_cache = { + 'batch_size': batch_size, + 'proj_mask': torch.zeros((batch_size, student_seq_len, max_n_chunks), + dtype=torch.bool, device=device), + 'tgt_mask': torch.zeros((batch_size, teacher_seq_len, max_n_chunks), + dtype=torch.bool, device=device) + } + + proj_mask = self._chunk_cache['proj_mask'] + tgt_mask = self._chunk_cache['tgt_mask'] + + # Clear masks (in-place) + proj_mask.zero_() + tgt_mask.zero_() + + # Fill masks efficiently (limit number of chunks processed) + for batch_idx in range(batch_size): + for chunk_idx, alignment_pair in enumerate(aligned_pairs[batch_idx][:max_n_chunks]): + s1text, s2text, start1, end1, start2, end2 = alignment_pair[:6] + if start1 != -1 and start2 !=
-1: + proj_mask[batch_idx, start1:end1, chunk_idx] = True + tgt_mask[batch_idx, start2:end2, chunk_idx] = True + + # Efficient chunk computation with reduced precision + proj_chunks = torch.bmm(proj_mask[:, :, :max_n_chunks].transpose(1,2).to(compute_dtype), projected_probs) + tgt_log_chunks = torch.bmm(tgt_mask[:, :, :max_n_chunks].transpose(1,2).to(compute_dtype), teacher_log_probs) + + # Fast normalization + proj_sizes = proj_mask[:, :, :max_n_chunks].sum(dim=1, keepdim=True).to(compute_dtype).transpose(1,2) + tgt_sizes = tgt_mask[:, :, :max_n_chunks].sum(dim=1, keepdim=True).to(compute_dtype).transpose(1,2) + + proj_chunks.div_(proj_sizes + 1e-10) + tgt_log_chunks.div_(tgt_sizes + 1e-10) + + # In-place renormalization and log + proj_chunks.div_(proj_chunks.sum(dim=-1, keepdim=True) + 1e-10) + proj_log_chunks = torch.log(proj_chunks + 1e-10) + + chunk_mask = (proj_sizes.squeeze(-1) > 0) & (tgt_sizes.squeeze(-1) > 0) + + # Compute loss + if not reverse_kl: + loss_kl = torch.nn.functional.kl_div( + proj_log_chunks, tgt_log_chunks, reduction="none", log_target=True + ) + else: + loss_kl = torch.nn.functional.kl_div( + tgt_log_chunks, proj_log_chunks, reduction="none", log_target=True + ) + + if chunk_mask.sum() > 0: + loss_kl = (loss_kl * chunk_mask.unsqueeze(-1)).sum() / chunk_mask.sum() + else: + loss_kl = torch.tensor(0.0, device=device, requires_grad=True) + + loss_kl = loss_kl.to(student_logits.dtype) + + # Fast accuracy + with torch.no_grad(): + if chunk_mask.sum() > 0: + proj_argmax = proj_chunks.argmax(dim=-1) + tgt_argmax = torch.exp(tgt_log_chunks).argmax(dim=-1) + matches = ((proj_argmax == tgt_argmax) & chunk_mask).sum().item() + top1_accuracy = matches / chunk_mask.sum().item() + else: + top1_accuracy = 0.0 + + return loss_kl * (temperature ** 2), top1_accuracy diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 93eea10108..f77daf3407 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@
-12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, NotRequired, TypedDict +import sys +if sys.version_info >= (3, 11): + from typing import Literal, NotRequired, TypedDict +else: + from typing import Literal, TypedDict + from typing_extensions import NotRequired class ResponseDatasetConfig(TypedDict): diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py index 8386286c83..aa47340312 100644 --- a/nemo_rl/data/datasets/eval_datasets/__init__.py +++ b/nemo_rl/data/datasets/eval_datasets/__init__.py @@ -26,10 +26,12 @@ def load_eval_dataset(data_config): # mmlu if dataset_name.startswith("mmlu") and dataset_name != "mmlu_pro": + num_few_shot = data_config.get("num_few_shot", 0) if dataset_name == "mmlu": base_dataset = MMLUDataset( prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) else: language = dataset_name.split("_")[1] @@ -37,6 +39,7 @@ def load_eval_dataset(data_config): language=language, prompt_file=data_config["prompt_file"], system_prompt_file=data_config["system_prompt_file"], + num_few_shot=num_few_shot, ) elif dataset_name == "mmlu_pro": base_dataset = MMLUProDataset( diff --git a/nemo_rl/data/datasets/eval_datasets/mmlu.py b/nemo_rl/data/datasets/eval_datasets/mmlu.py index c9a373fc10..e94e90af91 100644 --- a/nemo_rl/data/datasets/eval_datasets/mmlu.py +++ b/nemo_rl/data/datasets/eval_datasets/mmlu.py @@ -14,6 +14,7 @@ """MMLU dataset and its variants.""" +from collections import defaultdict from typing import Any, Literal, Optional from datasets import load_dataset @@ -21,6 +22,8 @@ from nemo_rl.data import processors from nemo_rl.data.interfaces import TaskDataSpec +ANSWER_INDEX_TO_LETTER = {0: "A", 1: "B", 2: "C", 3: "D"} + class MMLUDataset: def __init__( @@ -44,6 +47,7 @@ def __init__( ] = "EN-US", prompt_file: Optional[str] 
= None, system_prompt_file: Optional[str] = None, + num_few_shot: int = 0, ): if language != "EN-US": data_files = f"https://openaipublic.blob.core.windows.net/simple-evals/mmlu_{language}.csv" @@ -58,6 +62,14 @@ def __init__( ) self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + if num_few_shot > 0: + few_shot_prefixes = self._build_few_shot_prefixes(num_few_shot) + self.rekeyed_ds = self.rekeyed_ds.map( + lambda ex: { + "few_shot_prefix": few_shot_prefixes.get(ex["subject"], "") + } + ) + self.task_spec = TaskDataSpec( task_name=f"MMLU_{language}", prompt_file=prompt_file, @@ -65,6 +77,34 @@ def __init__( ) self.processor = processors.multichoice_qa_processor + @staticmethod + def _build_few_shot_prefixes(num_few_shot: int) -> dict[str, str]: + """Load the MMLU dev set (cais/mmlu "all" validation split) and build + a per-subject formatted few-shot prefix string.""" + dev_ds = load_dataset("cais/mmlu", "all", split="validation") + + dev_by_subject: dict[str, list[dict[str, Any]]] = defaultdict(list) + for ex in dev_ds: + dev_by_subject[ex["subject"]].append(ex) + + prefixes: dict[str, str] = {} + for subject, examples in dev_by_subject.items(): + parts = [] + for fs_ex in examples[:num_few_shot]: + choices = fs_ex["choices"] + options_str = "\n".join( + f"{letter}) {choices[i]}" + for i, letter in enumerate(["A", "B", "C", "D"]) + ) + answer_letter = ANSWER_INDEX_TO_LETTER[fs_ex["answer"]] + parts.append( + f"Question: {fs_ex['question']}\nOptions:\n{options_str}\n" + f"Answer: {answer_letter}" + ) + prefixes[subject] = "\n\n".join(parts) + + return prefixes + def _rekey(self, data: dict[str, Any]): return { "question": data["Question"], diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index dec2b8e119..4448c44cb5 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -11,26 +11,14 @@ # WITHOUT WARRANTIES 
OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any -from nemo_rl.data import ResponseDatasetConfig -from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset +from nemo_rl.data.datasets.response_datasets.arrow_text_dataset import ArrowTextDataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset -from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset -from nemo_rl.data.datasets.response_datasets.dapo_math import ( - DAPOMath17KDataset, - DAPOMathAIME2024Dataset, -) +from nemo_rl.data.datasets.response_datasets.dapo_math import DAPOMath17KDataset from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset -from nemo_rl.data.datasets.response_datasets.general_conversations_dataset import ( - GeneralConversationsJsonlDataset, -) from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset -from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset -from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset -from nemo_rl.data.datasets.response_datasets.nemotron_cascade2_sft import ( - NemotronCascade2SFTMathDataset, -) from nemo_rl.data.datasets.response_datasets.oai_format_dataset import ( OpenAIFormatDataset, ) @@ -42,41 +30,108 @@ from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset from nemo_rl.data.datasets.response_datasets.squad import SquadDataset from nemo_rl.data.datasets.response_datasets.tulu3 import Tulu3SftMixtureDataset +from nemo_rl.data.datasets.utils import get_extra_kwargs -DATASET_REGISTRY = { - # built-in datasets - "AIME2024": AIME2024Dataset, - "clevr-cogent": CLEVRCoGenTDataset, - "daily-omni": DailyOmniDataset, - "general-conversation-jsonl": 
    GeneralConversationsJsonlDataset,
-    "DAPOMath17K": DAPOMath17KDataset,
-    "DAPOMathAIME2024": DAPOMathAIME2024Dataset,
-    "DeepScaler": DeepScalerDataset,
-    "geometry3k": Geometry3KDataset,
-    "HelpSteer3": HelpSteer3Dataset,
-    "open_assistant": OasstDataset,
-    "OpenMathInstruct-2": OpenMathInstruct2Dataset,
-    "refcoco": RefCOCODataset,
-    "squad": SquadDataset,
-    "tulu3_sft_mixture": Tulu3SftMixtureDataset,
-    "gsm8k": GSM8KDataset,
-    "Nemotron-Cascade-2-SFT-Math": NemotronCascade2SFTMathDataset,
-    # load from local JSONL file or HuggingFace
-    "openai_format": OpenAIFormatDataset,
-    "NemoGymDataset": NemoGymDataset,
-    "ResponseDataset": ResponseDataset,
-}
-
-def load_response_dataset(data_config: ResponseDatasetConfig):
+# TODO: refactor this to use the new processor interface and RawDataset interface. https://github.com/NVIDIA-NeMo/RL/issues/1552
+def load_response_dataset(data_config, seed: int = 42):
     """Loads response dataset."""
     dataset_name = data_config["dataset_name"]
 
-    # load dataset
-    if dataset_name in DATASET_REGISTRY:
-        dataset_class = DATASET_REGISTRY[dataset_name]
-        dataset = dataset_class(
-            **data_config  # pyrefly: ignore[missing-argument] `data_path` is required for some classes
+    # TODO @yukih: remove duplicated dataset_name (openmathinstruct2, clevr_cogent)
+    # for sft training
+    if dataset_name == "open_assistant":
+        base_dataset = OasstDataset(
+            output_dir="/tmp/open_assistant",
+            seed=seed,
+        )
+    elif dataset_name == "squad":
+        base_dataset = SquadDataset()
+    elif dataset_name == "openmathinstruct2":
+        base_dataset = OpenMathInstruct2Dataset(
+            split=data_config["split"],
+            output_key=data_config["output_key"],
+            prompt_file=data_config["prompt_file"],
+            seed=seed,
+        )
+    elif dataset_name == "clevr_cogent":
+        base_dataset = CLEVRCoGenTDataset(
+            split=data_config["split"],
+            prompt_file=data_config["prompt_file"],
+        )
+    elif dataset_name == "openai_format":
+        base_dataset = OpenAIFormatDataset(
+            data_config["train_data_path"],
+            data_config["val_data_path"],
+            data_config["chat_key"],
+            data_config["system_key"],
+            data_config["system_prompt"],
+            data_config["tool_key"],
+            data_config["use_preserving_dataset"],
+        )
+    elif dataset_name == "arrow_text":
+        base_dataset = ArrowTextDataset(
+            arrow_files=data_config["arrow_files"],
+            val_split=data_config.get("val_split", 0.05),
+            seed=seed,
+            text_key=data_config.get("text_key", "text"),
+        )
+    # for rl training
+    elif dataset_name == "OpenMathInstruct-2":
+        print("Loading nvidia/OpenMathInstruct2Dataset for training and validation")
+        base_dataset: Any = OpenMathInstruct2Dataset(seed=seed)
+    elif dataset_name == "DeepScaler":
+        print(
+            "Loading agentica-org/DeepScaleR-Preview-Dataset for training and validation"
+        )
+        base_dataset: Any = DeepScalerDataset(seed=seed)
+    elif dataset_name == "DAPOMath17K":
+        print(
+            "Loading BytedTsinghua-SIA/DAPO-Math-17k for training and AIME 2024 for validation"
+        )
+        base_dataset: Any = DAPOMath17KDataset(seed=seed)
+    # for vlm rl training
+    elif dataset_name == "clevr-cogent":
+        base_dataset: Any = CLEVRCoGenTDataset(
+            split=data_config["split"],
+        )
+    elif dataset_name == "refcoco":
+        base_dataset: Any = RefCOCODataset(
+            split=data_config["split"],
+            download_dir=data_config["download_dir"],
+        )
+    elif dataset_name == "geometry3k":
+        base_dataset: Any = Geometry3KDataset(
+            split=data_config["split"],
+        )
+    elif dataset_name == "tulu3_sft_mixture":
+        base_dataset: Any = Tulu3SftMixtureDataset(
+            test_size=data_config.get("test_size", 0.05),
+            prompt_file=data_config.get("prompt_file", None),
+            max_samples=data_config.get("max_samples", None),
+            seed=seed,
+        )
+    elif dataset_name == "HelpSteer3":
+        base_dataset: Any = HelpSteer3Dataset()
+    # fall back to load from JSON file
+    elif dataset_name == "ResponseDataset":
+        if "train_data_path" not in data_config:
+            raise ValueError(
+                "train_data_path is required when dataset_name is not one of the built-ins."
+            )
+        extra_kwargs = get_extra_kwargs(
+            data_config,
+            [
+                "val_data_path",
+                "input_key",
+                "output_key",
+                "train_split",
+                "val_split",
+            ],
+        )
+        base_dataset = ResponseDataset(
+            train_data_path=data_config["train_data_path"],
+            **extra_kwargs,
         )
     else:
         raise ValueError(
@@ -85,33 +140,35 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
             "or set dataset_name=ResponseDataset to load from local JSONL file or HuggingFace."
         )
 
-    # bind prompt, system prompt and data processor
-    dataset.set_task_spec(data_config)
-    # Remove this after the data processor is refactored. https://github.com/NVIDIA-NeMo/RL/issues/1658
-    dataset.set_processor()
+    base_dataset.set_task_spec(data_config)
+    # Skip sft datasets, the run_sft.py has not been refactored yet.
+    # TODO: refactor run_sft.py to use the new processor interface. https://github.com/NVIDIA-NeMo/RL/issues/1552
+    if dataset_name not in [
+        "open_assistant",
+        "squad",
+        "openmathinstruct2",
+        "clevr_cogent",
+        "openai_format",
+        "tulu3_sft_mixture",
+        "arrow_text",
+    ]:
+        base_dataset.set_processor()
 
-    return dataset
+    return base_dataset
 
 
 __all__ = [
-    "AIME2024Dataset",
+    "ArrowTextDataset",
     "CLEVRCoGenTDataset",
-    "DailyOmniDataset",
-    "GeneralConversationsJsonlDataset",
-    "DAPOMath17KDataset",
-    "DAPOMathAIME2024Dataset",
-    "GSM8KDataset",
     "DeepScalerDataset",
+    "DAPOMath17KDataset",
     "Geometry3KDataset",
-    "HelpSteer3Dataset",
-    "NemoGymDataset",
-    "NemotronCascade2SFTMathDataset",
-    "OasstDataset",
     "OpenAIFormatDataset",
+    "OasstDataset",
     "OpenMathInstruct2Dataset",
     "RefCOCODataset",
     "ResponseDataset",
     "SquadDataset",
     "Tulu3SftMixtureDataset",
-    "load_response_dataset",
+    "HelpSteer3Dataset",
 ]
diff --git a/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py
new file mode 100644
index 0000000000..8e46519671
--- /dev/null
+++ b/nemo_rl/data/datasets/response_datasets/arrow_text_dataset.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Arrow Text Dataset for loading arrow files with 'text' column."""
+
+import glob
+from typing import Any, Optional
+
+from datasets import Dataset, load_dataset
+
+from nemo_rl.data.datasets.raw_dataset import RawDataset
+
+
+class ArrowTextDataset(RawDataset):
+    """Dataset class for loading arrow files containing raw text.
+
+    This class loads arrow files with a 'text' column and converts them to
+    the messages format expected by SFT training.
+
+    The text is wrapped as an assistant message:
+        {"messages": [{"role": "assistant", "content": <text>}]}
+
+    This format allows training on all tokens (language modeling style).
+
+    Args:
+        arrow_files: Path pattern (glob) or list of arrow file paths
+        val_split: Fraction of data to use for validation (default: 0.05)
+        seed: Random seed for train/val split
+        text_key: Key for text column in arrow files (default: "text")
+
+    Example config:
+        data:
+          dataset_name: "arrow_text"
+          arrow_files: "/path/to/data/*.arrow"
+          val_split: 0.05
+          max_input_seq_length: 4096
+    """
+
+    def __init__(
+        self,
+        arrow_files: str | list[str],
+        val_split: float = 0.05,
+        seed: int = 42,
+        text_key: str = "text",
+    ):
+        # Don't call super().__init__() since RawDataset raises NotImplementedError
+        self.seed = seed
+        self.text_key = text_key
+        self.task_name = "arrow_text_dataset"
+
+        # Resolve glob pattern if string
+        if isinstance(arrow_files, str):
+            file_list = glob.glob(arrow_files)
+            if not file_list:
+                raise ValueError(f"No arrow files found matching pattern: {arrow_files}")
+        else:
+            file_list = arrow_files
+
+        print(f"Loading {len(file_list)} arrow files...")
+        dataset = load_dataset("arrow", data_files=file_list, split="train")
+        print(f" ✓ Loaded {len(dataset)} total samples")
+
+        # Verify text column exists
+        if self.text_key not in dataset.column_names:
+            raise ValueError(
+                f"Column '{self.text_key}' not found in arrow files. "
+                f"Available columns: {dataset.column_names}"
+            )
+
+        # Convert text to messages format
+        def text_to_messages(example: dict[str, Any]) -> dict[str, Any]:
+            """Convert raw text to messages format for SFT training."""
+            text = example[self.text_key]
+            return {
+                "messages": [{"role": "assistant", "content": text}]
+            }
+
+        formatted_dataset = dataset.map(text_to_messages, remove_columns=dataset.column_names)
+
+        # Split into train/validation
+        if val_split > 0:
+            split = formatted_dataset.train_test_split(test_size=val_split, seed=seed)
+            train_dataset = split["train"]
+            val_dataset = split["test"]
+        else:
+            train_dataset = formatted_dataset
+            # Create a small validation set from the end
+            val_dataset = formatted_dataset.select(range(min(100, len(formatted_dataset))))
+
+        print(f" ✓ Train: {len(train_dataset)}, Validation: {len(val_dataset)}")
+
+        self.formatted_ds = {
+            "train": train_dataset,
+            "validation": val_dataset,
+        }
diff --git a/nemo_rl/data/interfaces.py b/nemo_rl/data/interfaces.py
index f3f88b3b5e..91e0638857 100644
--- a/nemo_rl/data/interfaces.py
+++ b/nemo_rl/data/interfaces.py
@@ -13,7 +13,12 @@
 # limitations under the License.
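The new `ArrowTextDataset` boils down to two steps: wrap each raw-text row as a single assistant message, then split deterministically into train/validation. A minimal pure-Python sketch of that logic (plain dicts and a manual shuffle-then-split stand in for the Hugging Face `datasets` objects, so this illustrates the transformation only, not the actual loader):

```python
import random

def text_to_messages(example: dict, text_key: str = "text") -> dict:
    # Wrap raw text as a single assistant message (language-modeling style),
    # mirroring the map() step in ArrowTextDataset.
    return {"messages": [{"role": "assistant", "content": example[text_key]}]}

def split_train_val(rows: list, val_split: float = 0.05, seed: int = 42) -> tuple:
    # Deterministic shuffle-then-split, standing in for
    # datasets.Dataset.train_test_split(test_size=val_split, seed=seed).
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_val = int(len(rows) * val_split) if val_split > 0 else 0
    return rows[n_val:], rows[:n_val]

raw = [{"text": f"sample {i}"} for i in range(100)]
formatted = [text_to_messages(r) for r in raw]
train, val = split_train_val(formatted)
print(len(train), len(val))  # 95 5
```

Because every token lives in one assistant turn, downstream SFT masking treats the whole sequence as trainable, which is what "training on all tokens" means above.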
 import os
 from dataclasses import dataclass
-from typing import Any, NotRequired, Optional, Protocol, TypedDict, Union
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, NotRequired, Optional, Protocol, TypedDict, Union
+else:
+    from typing import Any, Optional, Protocol, TypedDict, Union
+    from typing_extensions import NotRequired
 
 import torch
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 2b9a3900f1..34c80a11d5 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -657,10 +657,15 @@ def vlm_hf_data_processor(
 
 
 def _construct_multichoice_prompt(
-    prompt: str, question: str, options: dict[str, str]
+    prompt: str,
+    question: str,
+    options: dict[str, str],
+    few_shot_prefix: str = "",
 ) -> str:
-    """Construct prompt from question and options."""
+    """Construct prompt from question and options, with optional few-shot examples."""
     output = prompt
+    if few_shot_prefix:
+        output += f"\n\n{few_shot_prefix}"
     output += f"\n\nQuestion: {question}\nOptions:\n"
     output += "\n".join(
         [
@@ -709,7 +714,10 @@ def multichoice_qa_processor(
     # user prompt
     if task_data_spec.prompt:
         question = _construct_multichoice_prompt(
-            task_data_spec.prompt, question, options
+            task_data_spec.prompt,
+            question,
+            options,
+            few_shot_prefix=datum_dict.get("few_shot_prefix", ""),
         )
     user_message = {"role": "user", "content": question}
     message = tokenizer.apply_chat_template(
diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py
index 95677873a4..7bb62b0ab8 100644
--- a/nemo_rl/distributed/ray_actor_environment_registry.py
+++ b/nemo_rl/distributed/ray_actor_environment_registry.py
@@ -33,6 +33,7 @@
     "nemo_rl.models.generation.sglang.sglang_worker.SGLangGenerationWorker": SGLANG_EXECUTABLE,
     "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.FSDP,
     "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
+    "nemo_rl.models.policy.workers.dtensor_distillation_worker.DTensorDistillationWorker": PY_EXECUTABLES.AUTOMODEL,
     "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
     "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.math_environment.MathMultiRewardEnvironment": PY_EXECUTABLES.SYSTEM,
diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py
index 6da76d04db..ba14e66e84 100644
--- a/nemo_rl/environments/math_environment.py
+++ b/nemo_rl/environments/math_environment.py
@@ -15,7 +15,12 @@
 import io
 import logging
 import re
-from typing import Any, NotRequired, TypedDict, Union
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, NotRequired, TypedDict, Union
+else:
+    from typing import Any, TypedDict, Union
+    from typing_extensions import NotRequired
 
 import ray
 import torch
diff --git a/nemo_rl/environments/reward_model_environment.py b/nemo_rl/environments/reward_model_environment.py
index eee7af9a16..84d3952e47 100644
--- a/nemo_rl/environments/reward_model_environment.py
+++ b/nemo_rl/environments/reward_model_environment.py
@@ -13,7 +13,12 @@
 # limitations under the License.
 
 import os
-from typing import Any, Dict, List, NotRequired, Optional, Tuple, TypedDict
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, Dict, List, NotRequired, Optional, Tuple, TypedDict
+else:
+    from typing import Any, Dict, List, Optional, Tuple, TypedDict
+    from typing_extensions import NotRequired
 
 import ray
 import torch
diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py
index df82c7d1af..6241574c21 100644
--- a/nemo_rl/environments/utils.py
+++ b/nemo_rl/environments/utils.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Dict, NotRequired, TypedDict
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, Dict, NotRequired, TypedDict
+else:
+    from typing import Any, Dict, TypedDict
+    from typing_extensions import NotRequired
 
 from hydra.utils import get_object
diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py
index 037b4880f5..3de92fa805 100644
--- a/nemo_rl/models/generation/interfaces.py
+++ b/nemo_rl/models/generation/interfaces.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, NotRequired, TypedDict, Union
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, NotRequired, TypedDict, Union
+else:
+    from typing import Any, TypedDict, Union
+    from typing_extensions import NotRequired
 
 import ray
 import torch
diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py
index 857ed177c2..5172ee47f7 100644
--- a/nemo_rl/models/generation/vllm/config.py
+++ b/nemo_rl/models/generation/vllm/config.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Literal, NotRequired, TypedDict
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, Literal, NotRequired, TypedDict
+else:
+    from typing import Any, Literal, TypedDict
+    from typing_extensions import NotRequired
 
 from nemo_rl.models.generation.interfaces import GenerationConfig
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index ec4c9e66bb..5f90650881 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Literal, NotRequired, TypedDict, Union
+import sys
+if sys.version_info >= (3, 11):
+    from typing import Any, Literal, NotRequired, TypedDict, Union
+else:
+    from typing import Any, Literal, TypedDict, Union
+    from typing_extensions import NotRequired
 
 from nemo_rl.models.generation.interfaces import GenerationConfig
diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py
index 8acd808b11..5e24c4083b 100644
--- a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -14,7 +14,6 @@
 import os
 import warnings
 from collections import defaultdict
-from contextlib import nullcontext
 from typing import Any, Optional, Union
 
 import numpy as np
@@ -23,7 +22,7 @@
 from ray.util.queue import Queue as RayQueue
 from transformers import AutoProcessor, PreTrainedTokenizerBase
 
-from nemo_rl.algorithms.loss.interfaces import LossFunction
+from nemo_rl.algorithms.interfaces import LossFunction
 from nemo_rl.distributed.batched_data_dict import (
     BatchedDataDict,
     DynamicBatchingArgs,
@@ -52,7 +51,6 @@
     get_default_hf_config,
     get_theoretical_tflops,
 )
-from nemo_rl.utils.timer import Timer
 
 PathLike = Union[str, "os.PathLike[Any]"]
@@ -70,41 +68,28 @@ def __init__(
         optimizer_path: Optional[PathLike] = None,
         init_reference_model: bool = True,
         processor: Optional[AutoProcessor] =
None, - worker_extension_cls_fqn: Optional[str] = None, + worker_builder_cls_override: Optional[str] = None, + extra_worker_kwargs: Optional[dict[str, Any]] = None, ): if weights_path: weights_path = os.path.abspath(weights_path) if optimizer_path: optimizer_path = os.path.abspath(optimizer_path) - worker_builder_cls_fqn: str + worker_builder_cls: str tp_size = 1 pp_size = 1 cp_size = 1 - use_v2 = False megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False)) dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False)) - draft_enabled = bool(config.get("draft", {}).get("enabled", False)) if megatron_enable and dtensor_enable: raise ValueError( "Configure either Megatron (policy.megatron_cfg.enabled=true) or " "DTensor (policy.dtensor_cfg.enabled=true), not both." ) - if draft_enabled and not megatron_enable: - raise ValueError( - "policy.draft.enabled=true is only supported with the Megatron backend. " - "Set policy.megatron_cfg.enabled=true or disable policy.draft." - ) - if draft_enabled and bool( - config.get("sequence_packing", {}).get("enabled", False) - ): - raise ValueError( - "policy.draft.enabled=true does not support sequence packing yet. " - "Disable policy.sequence_packing.enabled or policy.draft." 
- ) if megatron_enable: - worker_builder_cls_fqn = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" + worker_builder_cls = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" tp_size = config["megatron_cfg"]["tensor_model_parallel_size"] pp_size = config["megatron_cfg"]["pipeline_model_parallel_size"] cp_size = config["megatron_cfg"]["context_parallel_size"] @@ -127,7 +112,7 @@ def __init__( # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility) use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) if use_v2: - worker_builder_cls_fqn = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" + worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" if "TORCH_CUDA_ARCH_LIST" not in os.environ: warnings.warn( @@ -139,43 +124,17 @@ def __init__( config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False) is False ), "LoRA is not supported for DTensorPolicyWorker V1" - worker_builder_cls_fqn = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" + worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" tp_size = config["dtensor_cfg"]["tensor_parallel_size"] cp_size = config["dtensor_cfg"]["context_parallel_size"] env_vars = config["dtensor_cfg"].get("env_vars", {}) - # If a worker extension class is provided, use it instead of the default worker builder class - if worker_extension_cls_fqn is not None: - print( - f"Using worker extension class: {worker_extension_cls_fqn}, please make sure it is a subclass of {worker_builder_cls_fqn}." 
- ) - worker_builder_cls_fqn = worker_extension_cls_fqn - # Validate world_size compatibility with parallelism configuration model_parallel_size = pp_size * cp_size * tp_size actual_world_size = cluster.world_size() - if ( - not bool(os.environ.get("NRL_IGNORE_TP_ACCURACY_CHECK")) - and "logprob_batch_size" in config - and tp_size >= 4 - ): - sep_line = "\n" + ("-" * 80) - assert config["train_micro_batch_size"] == config["logprob_batch_size"], ( - f"{sep_line}\n" - "There is a known batch-variant accuracy issue with TP>=4 for both DTensor and Megatron backend.\n" - "See https://docs.nvidia.com/nemo/rl/latest/guides/dtensor-tp-accuracy.html#root-cause for more details.\n" - "\n" - "Please choose either of the following solutions to avoid this issue:\n" - "1. Set tp_size to 1 or 2. (tensor_parallel_size for DTensor, or tensor_model_parallel_size for Megatron)\n" - "2. Set policy.train_micro_batch_size and policy.logprob_batch_size to be the same value.\n" - "3. Set loss_fn.force_on_policy_ratio=true to force ratio=1.0, this requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt.\n" - "4. Set NRL_IGNORE_TP_ACCURACY_CHECK=1 to bypass this check. (not recommended)" - f"{sep_line}\n" - ) - if actual_world_size < model_parallel_size: raise ValueError( f"World size ({actual_world_size}) is insufficient for the parallelism configuration. 
" @@ -208,29 +167,23 @@ def __init__( ], ) - pre_init_queue = RayQueue() + if worker_builder_cls_override is not None: + worker_builder_cls = worker_builder_cls_override - worker_kwargs = dict( + pre_init_queue = RayQueue() + _extra = extra_worker_kwargs or {} + worker_builder = RayWorkerBuilder( + worker_builder_cls, + config, + tokenizer=tokenizer, + processor=processor, init_optimizer=init_optimizer, weights_path=weights_path, optimizer_path=optimizer_path, init_reference_model=init_reference_model, worker_sharding_annotations=self.sharding_annotations, pre_init_communication_queue=pre_init_queue, - ) - - if use_v2: - # DTensor v2 workers reconstruct tokenizer/processor locally to avoid - # pickling across incompatible transformers versions (v4 head → v5 worker). - config["tokenizer"]["use_processor"] = processor is not None - else: - worker_kwargs["tokenizer"] = tokenizer - worker_kwargs["processor"] = processor - - worker_builder = RayWorkerBuilder( - worker_builder_cls_fqn, - config, - **worker_kwargs, + **_extra, ) if cluster._sorted_bundle_indices is not None: @@ -261,7 +214,7 @@ def __init__( env_vars=env_vars or {}, ) - if config["dynamic_batching"]["enabled"]: + if config.get("dynamic_batching", {}).get("enabled", False): assert pp_size == 1, ( "Dynamic batching is only supported for single pipeline parallel stage" ) @@ -274,7 +227,7 @@ def __init__( ], "max_tokens_per_microbatch": 0, # Override this in each different call (presumably different sizes) } - assert not config["sequence_packing"]["enabled"], ( + assert not config.get("sequence_packing", {}).get("enabled", False), ( "Dynamic Batching is exclusive of Sequence Packing. 
Please disable Sequence Packing to use Dynamic Batching" ) else: @@ -289,16 +242,18 @@ def __init__( self.flops_tracker = None print(f"FLOPS tracker not supported for model {config['model_name']}: {e}") - if config["sequence_packing"]["enabled"]: + if config.get("sequence_packing", {}).get("enabled", False): self.use_sequence_packing = True - sequence_length_pad_multiple = config["make_sequence_length_divisible_by"] + sequence_length_pad_multiple = ( + cp_size * 2 * tp_size if cp_size > 1 else tp_size + ) self.sequence_packing_args: SequencePackingArgs = { "algorithm": config["sequence_packing"]["algorithm"], "input_key": "input_ids", "input_lengths_key": "input_lengths", "sequence_length_pad_multiple": sequence_length_pad_multiple, } - assert not config["dynamic_batching"]["enabled"], ( + assert not config.get("dynamic_batching", {}).get("enabled", False), ( "Sequence Packing is exclusive of Dynamic Batching. Please disable Dynamic Batching" ) else: @@ -306,44 +261,6 @@ def __init__( self.cfg = config - def run_all_workers_single_data(self, method_name: str, *args, **kwargs) -> Any: - """Run a method on all workers in parallel with the same data. - - Mainly used for worker extension classes. - - Args: - method_name: The name of the method to run. - *args: The positional arguments to pass to the method. - **kwargs: The keyword arguments to pass to the method. - - Returns: - The results of the method run on all workers. - """ - futures = self.worker_group.run_all_workers_single_data( - method_name, *args, **kwargs - ) - results = ray.get(futures) - return results - - def run_all_workers_multiple_data(self, method_name: str, *args, **kwargs) -> Any: - """Run a method on all workers in parallel with different data. - - Mainly used for worker extension classes. - - Args: - method_name: The name of the method to run. - *args: The positional arguments to pass to the method. - **kwargs: The keyword arguments to pass to the method. 
- - Returns: - The results of the method run on all workers. - """ - futures = self.worker_group.run_all_workers_multiple_data( - method_name, *args, **kwargs - ) - results = ray.get(futures) - return results - def init_collective( self, ip: str, port: int, world_size: int, *, train_world_size: int ) -> list[ray.ObjectRef]: @@ -359,9 +276,7 @@ def init_collective( return futures def get_logprobs( - self, - data: BatchedDataDict[GenerationDatumSpec], - timer: Optional[Timer] = None, + self, data: BatchedDataDict[GenerationDatumSpec] ) -> BatchedDataDict[LogprobOutputSpec]: """Get the logprobs of the model for a data dict. @@ -374,52 +289,46 @@ def get_logprobs( sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] - with timer.time("get_logprobs/shard_data") if timer else nullcontext(): - if self.use_dynamic_batches: - self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ - "dynamic_batching" - ]["logprob_mb_tokens"] - sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - dynamic_batching_args=self.dynamic_batching_args, - ) - elif self.use_sequence_packing: - self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ - "sequence_packing" - ]["logprob_mb_tokens"] - # we just shard into DP shards here as Sequence packing allows for CP. 
- sharded_data, unsorted_data_indices = data.shard_by_batch_size( - dp_size, - batch_size=None, - sequence_packing_args=self.sequence_packing_args, - ) - else: - sharded_data = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - ) - - with ( - timer.time("get_logprobs/submit_logprob_futures") - if timer - else nullcontext() - ): - futures = self.worker_group.run_all_workers_sharded_data( - "get_logprobs", - data=sharded_data, - in_sharded_axes=["data_parallel"], - replicate_on_axes=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - output_is_replicated=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, + ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + # we just shard into DP shards here as Sequence packing allows for CP. 
+ sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) + else: + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, ) + + futures = self.worker_group.run_all_workers_sharded_data( + "get_logprobs", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + ) logprobs: BatchedDataDict[LogprobOutputSpec] = BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures) ) @@ -435,7 +344,6 @@ def get_reference_policy_logprobs( self, data: BatchedDataDict[GenerationDatumSpec], micro_batch_size: Optional[int] = None, - timer: Optional[Timer] = None, ) -> BatchedDataDict[ReferenceLogprobOutputSpec]: """Get the logprobs of the reference policy for a data dict. @@ -444,58 +352,46 @@ def get_reference_policy_logprobs( dp_size = self.sharding_annotations.get_axis_size("data_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] - with ( - timer.time("get_reference_policy_logprobs/shard_data") - if timer - else nullcontext() - ): - if self.use_dynamic_batches: - self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ - "dynamic_batching" - ]["logprob_mb_tokens"] - sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - dynamic_batching_args=self.dynamic_batching_args, - ) - elif self.use_sequence_packing: - self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ - "sequence_packing" - ]["logprob_mb_tokens"] - sharded_data, unsorted_data_indices = data.shard_by_batch_size( - dp_size, - batch_size=None, - sequence_packing_args=self.sequence_packing_args, - ) - else: - sharded_data = data.shard_by_batch_size( # type: ignore - dp_size, - 
batch_size=None, - ) - - with ( - timer.time( - "get_reference_policy_logprobs/submit_reference_policy_logprob_futures" + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, ) - if timer - else nullcontext() - ): - futures = self.worker_group.run_all_workers_sharded_data( - "get_reference_policy_logprobs", - data=sharded_data, - in_sharded_axes=["data_parallel"], - replicate_on_axes=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - output_is_replicated=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - common_kwargs={"micro_batch_size": micro_batch_size}, + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, ) + else: + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + + futures = self.worker_group.run_all_workers_sharded_data( + "get_reference_policy_logprobs", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"micro_batch_size": micro_batch_size}, + ) logprobs: BatchedDataDict[ReferenceLogprobOutputSpec] = ( BatchedDataDict.from_batches( self.worker_group.get_all_worker_results(futures) @@ -514,62 +410,56 @@ def get_topk_logits( data: BatchedDataDict[GenerationDatumSpec], k: int, micro_batch_size: Optional[int] = None, - timer: Optional[Timer] = None, ) -> 
BatchedDataDict[TopkLogitsOutputSpec]: """Dispatch get_topk_logits to workers (no CP/packed support initially).""" dp_size = self.sharding_annotations.get_axis_size("data_parallel") sharded_data: list[SlicedDataDict] unsorted_data_indices: list[int] - with timer.time("get_topk_logits/shard_data") if timer else nullcontext(): - if self.use_dynamic_batches: - self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ - "dynamic_batching" - ]["logprob_mb_tokens"] - sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - dynamic_batching_args=self.dynamic_batching_args, - ) - elif self.use_sequence_packing: - self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ - "sequence_packing" - ]["logprob_mb_tokens"] - # we just shard into DP shards here as Sequence packing allows for CP. - sharded_data, unsorted_data_indices = data.shard_by_batch_size( - dp_size, - batch_size=None, - sequence_packing_args=self.sequence_packing_args, - ) - else: - sharded_data = data.shard_by_batch_size( # type: ignore - dp_size, - batch_size=None, - ) - - with ( - timer.time("get_topk_logits/submit_topk_logits_futures") - if timer - else nullcontext() - ): - futures = self.worker_group.run_all_workers_sharded_data( - "get_topk_logits", - data=sharded_data, - in_sharded_axes=["data_parallel"], - replicate_on_axes=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - output_is_replicated=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, unsorted_data_indices = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, ) + elif self.use_sequence_packing: + 
self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + # we just shard into DP shards here as Sequence packing allows for CP. + sharded_data, unsorted_data_indices = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) + else: + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + + futures = self.worker_group.run_all_workers_sharded_data( + "get_topk_logits", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + ) # Avoid BatchedDataDict.from_batches here because it flattens rows for tensors with ndim>2 ([B,S,k] -> [B,S*k]). worker_batches = self.worker_group.get_all_worker_results(futures) + all_topk_logits = [wb["topk_logits"] for wb in worker_batches] all_topk_indices = [wb["topk_indices"] for wb in worker_batches] @@ -582,6 +472,84 @@ def get_topk_logits( return stacked + def teacher_forward( + self, + data: BatchedDataDict, + k: int, + micro_batch_size: Optional[int] = None, + ) -> None: + """Dispatch teacher_forward to workers (distillation worker only). + + Each worker runs the teacher model forward pass and stores top-k + logprobs in GPU IPC buffers. No data is returned through Ray — + the subsequent train() call reads from self.teacher_logits on + each worker directly. 
+ """ + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, + ) + else: + sharded_data = data.shard_by_batch_size( + dp_size, + batch_size=None, + ) + + futures = self.worker_group.run_all_workers_sharded_data( + "teacher_forward", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"k": k, "micro_batch_size": micro_batch_size}, + ) + self.worker_group.get_all_worker_results(futures) + + def init_cross_tokenizer_loss_fn(self, loss_config, token_aligner_config) -> None: + """Have each worker build its own CrossTokenizerDistillationLossFn from config + shared filesystem.""" + futures = self.worker_group.run_all_workers_single_data( + "init_cross_tokenizer_loss_fn", + loss_config=loss_config, + token_aligner_config=token_aligner_config, + ) + ray.get(futures) + + def update_cross_tokenizer_data(self, teacher_input_ids, aligned_pairs) -> None: + """Update per-step cross-tokenizer data on all workers' cached loss functions. + + Shards the data so each worker only receives its slice, not the full batch. 
+        """
+        dp_size = self.sharding_annotations.get_axis_size("data_parallel")
+        batch_size = teacher_input_ids.shape[0]
+        assert batch_size % dp_size == 0, (
+            f"batch_size ({batch_size}) must be divisible by dp_size ({dp_size}); "
+            "otherwise trailing rows would be silently dropped"
+        )
+        shard_size = batch_size // dp_size
+
+        futures = self.worker_group.run_all_workers_multiple_data(
+            "update_cross_tokenizer_data",
+            teacher_input_ids=[
+                teacher_input_ids[i * shard_size : (i + 1) * shard_size]
+                for i in range(dp_size)
+            ],
+            aligned_pairs=[
+                aligned_pairs[i * shard_size : (i + 1) * shard_size]
+                for i in range(dp_size)
+            ],
+        )
+        ray.get(futures)
+
     def train(
         self,
         data: BatchedDataDict[Any],
@@ -589,37 +557,39 @@ def train(
         eval_mode: bool = False,
         gbs: Optional[int] = None,
         mbs: Optional[int] = None,
-        timer: Optional[Timer] = None,
+        is_teacher: bool = False,
+        teacher_logits: Optional[Any] = None,
+        topk_logits: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
         batch_size = gbs or self.cfg["train_global_batch_size"]
         micro_batch_size = mbs or self.cfg["train_micro_batch_size"]
+        # Shard and replicate the batch
         dp_size = self.sharding_annotations.get_axis_size("data_parallel")
-        with timer.time("policy_training/sharding_data") if timer else nullcontext():
-            if self.use_dynamic_batches:
-                self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[
-                    "dynamic_batching"
-                ]["train_mb_tokens"]
-                sharded_data, _ = data.shard_by_batch_size(
-                    dp_size,
-                    batch_size=batch_size,
-                    dynamic_batching_args=self.dynamic_batching_args,
-                )
-            elif self.use_sequence_packing:
-                self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[
-                    "sequence_packing"
-                ]["train_mb_tokens"]
-                sharded_data, _ = data.shard_by_batch_size(
-                    dp_size,
-                    batch_size=batch_size,
-                    sequence_packing_args=self.sequence_packing_args,
-                )
-            else:
-                sharded_data = data.shard_by_batch_size(
-                    dp_size,
-                    batch_size=batch_size,
-                )
+        if self.use_dynamic_batches:
+            self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[
+                "dynamic_batching"
+            ]["train_mb_tokens"]
+            sharded_data, _ =
data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + dynamic_batching_args=self.dynamic_batching_args, + ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["train_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + sequence_packing_args=self.sequence_packing_args, + ) + else: + sharded_data = data.shard_by_batch_size( + dp_size, + batch_size=batch_size, + ) if self.flops_tracker is not None: self.flops_tracker.reset() @@ -628,38 +598,50 @@ def train( self.flops_tracker.track_batch(input_lengths.tolist()) # Train each shard in parallel - with ( - timer.time("policy_training/submit_training_futures") - if timer - else nullcontext() - ): - futures = self.worker_group.run_all_workers_sharded_data( - "train", - data=sharded_data, - in_sharded_axes=["data_parallel"], - replicate_on_axes=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - output_is_replicated=[ - "context_parallel", - "tensor_parallel", - "pipeline_parallel", - ], - common_kwargs={ - "loss_fn": loss_fn, - "eval_mode": eval_mode, - "gbs": batch_size, - "mbs": micro_batch_size, - }, - ) + common_kwargs = { + "loss_fn": loss_fn, + "eval_mode": eval_mode, + "gbs": batch_size, + "mbs": micro_batch_size, + "is_teacher": is_teacher, + "teacher_logits": teacher_logits, + "topk_logits": topk_logits, + } + + if is_teacher: + output_replicated = [ + "context_parallel", + "pipeline_parallel", + ] + else: + output_replicated = [ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ] + + futures = self.worker_group.run_all_workers_sharded_data( + "train", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=output_replicated, + common_kwargs=common_kwargs, + ) results = self.worker_group.get_all_worker_results(futures) - # 
Aggregate the results
+        if is_teacher:
+            return results
+
         aggregated_results = {
             "loss": results[0]["global_loss"],
             "grad_norm": results[0]["grad_norm"],
+            # sum([0.0]) == 0.0, so a missing metric aggregates to 0.0
+            "kl_loss": sum(results[0]["all_mb_metrics"].get("kl_loss", [0.0])),
+            "nll_loss": sum(results[0]["all_mb_metrics"].get("nll_loss", [0.0])),
         }
         if "moe_metrics" in results[0]:
             aggregated_results["moe_metrics"] = results[0]["moe_metrics"]
@@ -893,19 +875,13 @@ def stream_weights_via_ipc_zmq(
         )
         return futures
 
-    def stream_weights_via_http(
-        self, sglang_url_to_gpu_uuids: dict[str, list[str]]
-    ) -> list[ray.ObjectRef]:
-        """Send the weights to SGLang servers via HTTP API.
-
-        Args:
-            sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses
-        """
+    def get_model_config(self):
+        """Get the model configuration from workers."""
         futures = self.worker_group.run_all_workers_single_data(
-            "stream_weights_via_http",
-            sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids,
+            "return_model_config"
         )
-        return futures
+        results = ray.get(futures)
+        return results[0]
 
     def broadcast_weights_for_collective(
         self, kv_scales: Optional[dict[str, float]] = None
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
index 022335f7d0..b1d9bababe 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
@@ -562,8 +562,16 @@ def train(
         eval_mode: bool = False,
         gbs: Optional[int] = None,
         mbs: Optional[int] = None,
+        is_teacher: bool = False,
+        teacher_logits: Optional[Any] = None,
+        topk_logits: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
+        if is_teacher or teacher_logits is not None:
+            raise NotImplementedError(
+                "IPC-based teacher/student distillation requires DTensorPolicyWorkerV2 "
+                "(set dtensor_cfg._v2=true in config)"
+ ) if gbs is None: gbs = self.cfg["train_global_batch_size"] if mbs is None: diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2fa8a8e604..dd79abfbd0 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -12,52 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import gc +import itertools +import os import warnings +from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any, Generator, Optional +from typing import Any, Generator, Optional, cast import ray import torch -from nemo_automodel.components._peft.lora import LinearLoRA +from accelerate import init_empty_weights +from hydra.utils import get_class +from nemo_automodel import ( + NeMoAutoModelForSequenceClassification, +) +from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) +from nemo_automodel.components.config.loader import _resolve_target from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, + get_train_context, ) -from nemo_automodel.components.distributed.cp_utils import ( - get_train_context as get_train_context_automodel, +from nemo_automodel.components.distributed.fsdp2 import ( + FSDP2Manager, ) from nemo_automodel.components.distributed.tensor_utils import ( get_cpu_state_dict, to_local_if_dtensor, ) +from nemo_automodel.components.moe.parallelizer import ( + parallelize_model as moe_parallelize_model, +) from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm from torch import nn -from torch.distributed.tensor import DTensor +from torch.distributed.fsdp import ( + CPUOffloadPolicy, + 
MixedPrecisionPolicy, +) +from torch.distributed.tensor import DTensor, Shard +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedModel, +) +from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams -from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.models.automodel.checkpoint import AutomodelCheckpointManager -from nemo_rl.models.automodel.data import ( - check_sequence_dim, - get_microbatch_iterator, - process_global_batch, -) -from nemo_rl.models.automodel.setup import ( - setup_distributed, - setup_model_and_optimizer, - setup_reference_model_state, - validate_and_prepare_config, +from nemo_rl.distributed.model_utils import ( + _compute_distributed_log_softmax, + allgather_cp_sharded_tensor, + distributed_vocab_topk, + get_logprobs_from_vocab_parallel_logits, ) -from nemo_rl.models.automodel.train import ( - LogprobsPostProcessor, - LossPostProcessor, - ScorePostProcessor, - TopkLogitsPostProcessor, - aggregate_training_statistics, - automodel_forward_backward, - forward_with_post_processing_fn, +from nemo_rl.models.policy.utils import get_handle_from_tensor, rebuild_cuda_tensor_from_ipc +from nemo_rl.models.huggingface.common import ( + get_flash_attention_kwargs, + pack_sequences, ) from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -65,132 +82,32 @@ LogprobOutputSpec, ScoreOutputSpec, ) -from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker +from nemo_rl.models.policy.utils import ( + configure_dynamo_cache, + get_runtime_env_for_policy_worker, + resolve_model_class, +) from nemo_rl.models.policy.workers.base_policy_worker import 
AbstractPolicyWorker from nemo_rl.models.policy.workers.patches import ( + apply_torch_aten_alias_tensor_patch, apply_transformer_engine_patch, ) +from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer - -def dtensor_params_generator( - model: nn.Module, target_dtype: torch.dtype -) -> Generator[tuple[str, torch.Tensor], None, None]: - """Generator that yields (name, tensor) pairs, converting DTensors to local tensors and adapting to HF format. - - Args: - model: The model whose parameters to generate. - target_dtype: The dtype to convert tensors to. - peft_config: Optional LoRA config for filtering which layers to merge. - - Yields: - Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous. - """ - module_map = dict(model.named_modules()) - for name, tensor in model.state_dict().items(): - if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): - continue - full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor - merged_tensor = _maybe_merge_lora_weight(module_map, name, full_tensor) - - adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, merged_tensor) - for adapted_fqn, adapted_tensor in adapted_fqn_tensors: - # Convert to target dtype - yield ( - adapted_fqn, - adapted_tensor.to(target_dtype, non_blocking=True).contiguous(), - ) - del adapted_tensor - del adapted_fqn_tensors - del merged_tensor - del full_tensor - - -@torch.no_grad() -def _maybe_merge_lora_weight( - module_map: dict[str, nn.Module], - fqn: str, - tensor: torch.Tensor, -) -> torch.Tensor: - if not fqn.endswith(".weight"): - return tensor - module_name = fqn[: -len(".weight")] - module = module_map.get(module_name) - if not isinstance(module, LinearLoRA): - return tensor - if not (hasattr(module, "lora_A") and 
hasattr(module, "lora_B")): - return tensor - - lora_a = ( - module.lora_A.weight.full_tensor() - if isinstance(module.lora_A.weight, DTensor) - else module.lora_A.weight - ) - lora_b = ( - module.lora_B.weight.full_tensor() - if isinstance(module.lora_B.weight, DTensor) - else module.lora_B.weight - ) - lora_a = lora_a.to(device=tensor.device, dtype=tensor.dtype) - lora_b = lora_b.to(device=tensor.device, dtype=tensor.dtype) - scale = getattr(module, "scale", None) - - if scale is None and hasattr(module, "alpha") and hasattr(module, "dim"): - scale = module.alpha / module.dim - if scale is None: - scale = 1.0 - - return tensor + torch.matmul(lora_b, lora_a) * scale - - -def _maybe_adapt_tensor_to_hf( - model_part: nn.Module, fqn: str, tensor: torch.Tensor, quantization: bool = False -) -> list[tuple[str, torch.Tensor]]: - adapter = getattr(model_part, "state_dict_adapter", None) - if adapter: - return adapter.convert_single_tensor_to_hf( - fqn, - tensor, - exclude_key_regex=r".*_extra_state.*", - quantization=quantization, - ) - return [(fqn, tensor)] - - -@contextlib.contextmanager -def get_train_context( - cp_size: int, - cp_mesh: Any, - cp_buffers: list, - sequence_dim: int, - dtype: torch.dtype, - autocast_enabled: bool = True, -) -> Generator[None, None, None]: - """Create combined context manager for training with context parallel and autocast.""" - with contextlib.ExitStack() as stack: - context_parallel_ctx = None - if cp_size > 1: - # Create context parallel context - context_parallel_ctx = create_context_parallel_ctx( - cp_mesh=cp_mesh, - cp_buffers=cp_buffers, - cp_seq_dims=[sequence_dim] * len(cp_buffers), - cp_no_restore_buffers=set(cp_buffers), - ) - - stack.enter_context( - get_train_context_automodel(False, False, context_parallel_ctx)() - ) - if autocast_enabled: - stack.enter_context(torch.autocast(device_type="cuda", dtype=dtype)) - yield +STRING_TO_DTYPE = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, 
+} -# Classes with @ray.remote can't be inherited from, so we split the implementation out. -# This is useful when using worker extension classes. -class DTensorPolicyWorkerV2Impl(AbstractPolicyWorker, ColocatablePolicyInterface): +@ray.remote( + runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") +) # pragma: no cover +class DTensorPolicyWorkerV2(AbstractPolicyWorker, ColocatablePolicyInterface): def __repr__(self) -> str: """Customizes the actor's prefix in the Ray logs. @@ -204,6 +121,8 @@ def __repr__(self) -> str: def __init__( self, config: PolicyConfig, + tokenizer: AutoTokenizer, + processor: Optional[AutoProcessor] = None, weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, init_optimizer: bool = True, @@ -213,116 +132,484 @@ def __init__( """Initialize the DTensorPolicyWorkerV2.""" # Apply TE patch until TE is upgraded to 2.10.0 apply_transformer_engine_patch() + # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered' + apply_torch_aten_alias_tensor_patch() + + self.tokenizer = tokenizer + self.processor = processor + self.is_vlm = processor is not None + + print(f"Initializing DTensorPolicyWorkerV2 with is_vlm={self.is_vlm}") + + self.is_generation_colocated = None + if "generation" in config and config["generation"] is not None: + self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. + # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. + if not self.is_generation_colocated: + os.environ["NCCL_CUMEM_ENABLE"] = "1" + + # Disable dynamo autotune_local_cache to avoid crash when there's already a cache + # with different order of node_bundles + configure_dynamo_cache() - # Store configuration self.cfg = config + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] + # torch distributed init. 
Envars for rank, world_size, and master_addr and master_port are set from the ray remote call + backend = "nccl" if not self.cpu_offload else "cuda:nccl,cpu:gloo" + torch.distributed.init_process_group(backend=backend) + self.rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + model_name = self.cfg["model_name"] - # Reconstruct tokenizer/processor locally to avoid pickling across - # incompatible transformers versions (v4 head node → v5 worker). - from nemo_rl.models.automodel.setup import get_tokenizer + self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None - use_processor = config["tokenizer"].get("use_processor", False) - result = get_tokenizer(config["tokenizer"], get_processor=use_processor) - if use_processor: - self.processor = result - self.tokenizer = result.tokenizer - else: - self.tokenizer = result - self.processor = None - self.is_vlm = self.processor is not None - self.lora_enabled = ( - config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False) + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] + self.offload_optimizer_for_logprob = self.cfg.get("offload_optimizer_for_logprob", False) + self.max_grad_norm = self.cfg["max_grad_norm"] + + try: + self.dtype = STRING_TO_DTYPE[self.cfg["precision"]] + except KeyError: + raise ValueError(f"Unknown precision: {self.cfg['precision']}") + + self.enable_seq_packing = self.cfg.get("sequence_packing", {}).get("enabled", False) + if self.enable_seq_packing: + assert not self.is_vlm, ( + "Sequence packing is not supported for VLM models. Please set policy.sequence_packing.enabled = False to train VLM models." 
+ ) + print( + f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}" + ) + print(f"[Rank {self.rank}] Using FlashAttention2 for sequence packing") + + hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} + + # Choose attention implementation on the following basis: + # - Packed sequence requires FA2 and CP must be 1 + # - CP > 1 requires SDPA + cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"] + + # NeMoAutoModelForCausalLM uses flash_attention_2 by default + # so we need to set it to None if sequence packing is disabled + # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 + attn_impl = ( + "flash_attention_2" + if (self.enable_seq_packing and cp_size_cfg == 1) + else ("sdpa" if cp_size_cfg > 1 else None) ) - print(f"Initializing DTensorPolicyWorkerV2 with is_vlm={self.is_vlm}") + model_config = AutoConfig.from_pretrained( + model_name, + # Always load the model in float32 to keep master weights in float32. + # Keeping the master weights in lower precision has shown to cause issues with convergence. + torch_dtype=torch.float32, + trust_remote_code=True, + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, + **hf_config_overrides, + ) - # Initialize checkpoint manager - self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + self.allow_flash_attn_args = self.check_model_allow_flash_attn_args( + model_config + ) - # Validate configuration and prepare runtime settings - runtime_config = validate_and_prepare_config( - config=config, - processor=self.processor, - rank=0, # Temporary, will be updated after distributed init + self._is_reward_model = ( + "reward_model_cfg" in self.cfg and self.cfg["reward_model_cfg"]["enabled"] ) + if self._is_reward_model: + # Ensure sequence packing is disabled. 
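The attention-implementation rule encoded above can be isolated as a tiny pure function (a sketch mirroring the selection comment; the function name is hypothetical, not part of this diff):

```python
def choose_attn_impl(enable_seq_packing: bool, cp_size: int):
    """Packed sequences need FlashAttention-2 (and CP must be 1);
    CP > 1 needs SDPA; otherwise defer to the library default (None)."""
    if enable_seq_packing and cp_size == 1:
        return "flash_attention_2"
    if cp_size > 1:
        return "sdpa"
    return None
```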
+ if self.enable_seq_packing: + raise NotImplementedError( + "Sequence packing is not supported for reward models" + ) + # Load model as a Reward Model. + rm_type = self.cfg["reward_model_cfg"]["reward_model_type"] + if rm_type == "bradley_terry": + model_class = NeMoAutoModelForSequenceClassification + if model_config.num_labels != 1: + # For Bradley-Terry reward models, the linear head has a single output. + # In the transformers library, the default setting for model_config.num_labels is 2 + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/configuration_utils.py#L259). + # Since num_labels is used as the out_features for the linear head + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/modeling_llama.py#L738) + # if num_labels is not 1, we set it to 1. This change may trigger a warning that some weights are not initialized + # from the model checkpoint and are instead initialized using model_config.initializer_range + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/configuration_llama.py#L62). + print( + "model_config.num_labels is not 1. Setting it to 1 since this value is used as the out_features " + "for the linear head of Bradley-Terry reward models." + ) + model_config.num_labels = 1 + else: + raise ValueError(f"Unknown reward model type: {rm_type}") + else: + # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc. 
+ model_class = resolve_model_class(model_config.model_type) + + # lora config + lora_cfg = self.cfg["dtensor_cfg"].get("lora_cfg", None) + self.peft_config = None + self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] + if self.lora_enabled: + if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1: + assert not lora_cfg["use_triton"], ( + "Triton is not supported when tensor_parallel_size > 1" + ) + # Always use float32 since FSDP requires all parameters to be in the same dtype. + # autocast should cast the weights to the correct dtype during the forward pass. + cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} + self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) + + print(f"[Rank {self.rank}] Initializing empty model for FSDP...") + # All ranks initialize model on meta device, so FSDP can shard it. + # The actual weights will be broadcast from rank 0. + + cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + automodel_kwargs = self.cfg["dtensor_cfg"].get("automodel_kwargs", {}) + if automodel_kwargs.get("backend", None) is not None: + backend_class = _resolve_target( + automodel_kwargs.get("backend", None)["_target_"] + ) + backend_kwargs = automodel_kwargs.get("backend") + backend_kwargs.pop("_target_") + backend = backend_class( + **backend_kwargs, + ) + automodel_kwargs["backend"] = backend + + if "use_liger_kernel" not in automodel_kwargs: + automodel_kwargs["use_liger_kernel"] = False + + with init_empty_weights(): + from torch.nn.attention import SDPBackend + + if cp_size > 1: + # Match Automodel's `get_train_context` in `cp_utils.py` where only + # flash and efficient backends are supported + # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57 + sdpa_method = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + ] + elif self.cfg["dtensor_cfg"]["activation_checkpointing"]: + # For activation 
checkpointing, we must disable the cudnn SDPA backend because + # it may not be selected during recomputation. + # In that case, we will get the following error: + # "Recomputed values have different metadata than during forward pass." + sdpa_method = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + else: + sdpa_method = None + + self.model = model_class.from_pretrained( + model_name, + attn_implementation=attn_impl, + torch_dtype=str(model_config.torch_dtype), + trust_remote_code=True, + config=model_config, + sdpa_method=sdpa_method, + **automodel_kwargs, + ) + if self.lora_enabled: + apply_lora_to_linear_modules(self.model, self.peft_config) + + # For activation checkpointing, we also must globally disable the cudnn SDPA backend + # to ensure that cudnn does not get selected during recomputation. + if self.cfg["dtensor_cfg"]["activation_checkpointing"]: + from torch.backends import cuda - # Set up distributed environment (returns DistributedContext) - distributed_context = setup_distributed( - config=config, - runtime_config=runtime_config, + cuda.enable_cudnn_sdp(False) + + # Hold a copy of model state_dict keys before any parallelization + self.model_state_dict_keys = list(self.model.state_dict().keys()) + + if self.model.config.pad_token_id is None: + self.model.config.pad_token_id = tokenizer.pad_token_id + + tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] + ep_size = self.cfg["dtensor_cfg"].get("expert_parallel_size", 1) + dp_size = None # will be inferred + if cp_size > 1 and self.enable_seq_packing: + raise ValueError( + "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." + ) + sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"] + + if sequence_parallel_enabled and tp_size == 1: + print( + "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. 
Enable tp_size > 1 to use sequence parallelism." + ) + + if cp_size > 1: + assert not isinstance(self.model, Gemma3ForCausalLM), ( + "Context parallel is not supported for Gemma3ForCausalLM. Torch context parallel has many limitations. " + "Please refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." + ) + + assert not (tp_size > 1 and sequence_parallel_enabled), ( + "It's a known issue that context parallel can't be used together with sequence parallel in DTensor worker. " + "Please either set cp_size = 1 or disable sequence parallel. " + "See https://github.com/NVIDIA-NeMo/RL/issues/659 for more details." + ) + + assert not self.is_vlm, ( + "Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models." + ) + + # ------------------------------------------------ + # Build device mesh and parallelize + # ------------------------------------------------ + manager = FSDP2Manager( + dp_size=dp_size, + dp_replicate_size=1, + tp_size=tp_size, + cp_size=cp_size, + ep_size=ep_size, + pp_size=1, + sequence_parallel=sequence_parallel_enabled, + use_hf_tp_plan=self.cfg["dtensor_cfg"].get("use_hf_tp_plan", False), + mp_policy=MixedPrecisionPolicy( + param_dtype=self.dtype, + reduce_dtype=torch.float32, + output_dtype=torch.float32, + ), + offload_policy=CPUOffloadPolicy(pin_memory=False) + if self.cpu_offload + else None, + backend="nccl", + world_size=world_size, + activation_checkpointing=self.cfg["dtensor_cfg"][ + "activation_checkpointing" + ], + custom_tp_plan=self.cfg["dtensor_cfg"].get("custom_parallel_plan", None), ) - # Set instance attributes from distributed context - self.rank = torch.distributed.get_rank() - self.device_mesh = distributed_context.device_mesh + + # Force setup distributed for world size 1 as FSDP2Manager skips it. 
+ if world_size == 1: + manager._setup_distributed() + + # Store mesh references for downstream usage + self.device_mesh = manager.device_mesh self.dp_cp_mesh = self.device_mesh["dp_cp"] self.dp_mesh = self.device_mesh["dp"] self.tp_mesh = self.device_mesh["tp"] self.cp_mesh = self.device_mesh["cp"] - self.moe_mesh = distributed_context.moe_mesh - self.dp_size = distributed_context.dp_size - self.tp_size = distributed_context.tp_size - self.cp_size = distributed_context.cp_size + self.moe_mesh = getattr(manager, "moe_mesh", None) - # Initialize checkpoint manager now that distributed is set up + self.dp_size = manager.dp_size + self.tp_size = manager.tp_size + self.cp_size = manager.cp_size + + # Parallelize model + is_moe_model = any(["expert" in key for key in self.model_state_dict_keys]) + is_hf_model = ( + model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls + ) + if ( + not isinstance(self.model, PreTrainedModel) + and is_moe_model + and not is_hf_model + ): + assert self.tp_size == 1, ( + "Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support tp_size > 1. Please use expert_parallel_size > 1 for custom implementation or set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." + ) + assert self.cp_size == 1, ( + "Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support cp_size > 1. Please set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." 
+ ) + moe_parallelize_model( + model=self.model, + world_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + pp_enabled=False, + dp_axis_names=( + ("dp_replicate", "dp_shard_cp") + if "dp_replicate" in self.device_mesh.mesh_dim_names + and "dp_shard_cp" in self.device_mesh.mesh_dim_names + else ("dp_shard_cp",) + ), + cp_axis_name="cp", + tp_axis_name="tp", + ep_axis_name="ep", + ep_shard_axis_names=("ep_shard",), + ) + else: + self.model = manager.parallelize(self.model) + + # Load base model weights across all ranks using Automodel Checkpointer + # This mirrors build_model_and_optimizer's is_meta_device + load_weights path + print(self.model) self._init_checkpoint_manager( config_updates={ - "model_repo_id": config["model_name"], - "dequantize_base_checkpoint": config.get( + "model_repo_id": model_name, + "dequantize_base_checkpoint": self.cfg.get( "dequantize_base_checkpoint", False ), "is_peft": self.lora_enabled, - "is_async": True, }, ) + self.checkpoint_manager.set_model_state_dict_keys(self.model_state_dict_keys) - # Set up model and optimizer - model_and_optimizer_state = setup_model_and_optimizer( - config=config, - tokenizer=self.tokenizer, - runtime_config=runtime_config, - distributed_context=distributed_context, - checkpoint_manager=self.checkpoint_manager, - is_vlm=self.is_vlm, - init_optimizer=init_optimizer, - weights_path=weights_path, - optimizer_path=optimizer_path, + # Load base HF weights unless an explicit checkpoint is provided later + # This puts shards directly into the parallelized model + self.checkpoint_manager.load_base_model( + self.model, + model_name=model_name, + hf_cache_dir=hf_config_overrides.get("cache_dir", None), + dequantize_base_checkpoint=self.cfg.get( + "dequantize_base_checkpoint", False + ), + peft_init_method=self.peft_config.lora_A_init + if self.peft_config is not None + else None, ) - # Set instance attributes from model and optimizer state (tuple unpacking) - ( - self.model, - self.optimizer, - self.scheduler, - 
self.is_hf_model, - self.is_moe_model, - self._is_reward_model, # Note: using underscore prefix for internal naming - self.model_class, - self.model_config, - self.peft_config, - self.autocast_enabled, - ) = model_and_optimizer_state - - # Initialize reference model if requested - self.reference_model_state_dict = None + # Handle tied word embeddings after loading the state dict + # We need to actually tie the parameters at the model level + is_tied_lm_head = hasattr(self.model, "lm_head") and getattr( + getattr(self.model, "config", {}), "tie_word_embeddings", False + ) + if is_tied_lm_head: + embed_tokens_weight = None + for name, param in self.model.named_parameters(): + if "embed_tokens" in name and name.endswith(".weight"): + embed_tokens_weight = param + break + + if embed_tokens_weight is not None: + self.model.lm_head.weight = embed_tokens_weight + print( + f"[Rank {self.rank}] lm_head weight tied: " + f"same object = {self.model.lm_head.weight is embed_tokens_weight}, " + f"embed norm = {embed_tokens_weight.data.float().norm().item():.4f}, " + f"lm_head norm = {self.model.lm_head.weight.data.float().norm().item():.4f}" + ) + else: + print(f"[Rank {self.rank}] WARNING: embed_tokens weight not found, lm_head NOT tied") + else: + print( + f"[Rank {self.rank}] lm_head tying skipped: " + f"has_lm_head={hasattr(self.model, 'lm_head')}, " + f"tie_word_embeddings={getattr(getattr(self.model, 'config', {}), 'tie_word_embeddings', 'MISSING')}" + ) + + if self.cpu_offload: + self.model = self.move_to_device(self.model, "cpu") + if init_reference_model: - self.reference_model_state_dict = setup_reference_model_state(self.model) - - # Set instance attributes from runtime config (tuple unpacking) - ( - self.model_class, # Already set above, but includes in tuple for completeness - self.model_config, # Already set above, but includes in tuple for completeness - self.hf_config_overrides, - self.allow_flash_attn_args, - self.attn_impl, - self.dtype, - 
self.enable_seq_packing, - self.max_grad_norm, - self.cpu_offload, - self.offload_optimizer_for_logprob, - self.is_generation_colocated, - self.sampling_params, - _runtime_is_reward_model, # Duplicate, already set as _is_reward_model - ) = runtime_config + self.reference_model_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) + + if init_optimizer: + optimizer_cls = get_class(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"], + ) + else: + self.optimizer = None + + if "scheduler" in self.cfg and self.optimizer is not None: + if isinstance(self.cfg["scheduler"], dict): + scheduler_cls = get_class(cast(str, self.cfg["scheduler"]["name"])) + self.scheduler = scheduler_cls( + self.optimizer, **self.cfg["scheduler"]["kwargs"] + ) + else: + schedulers = [] + for scheduler_cfg in self.cfg["scheduler"]: + if "name" in scheduler_cfg: + schedulers.append( + get_class(scheduler_cfg["name"])( + self.optimizer, **scheduler_cfg["kwargs"] + ) + ) + else: + assert "milestones" in scheduler_cfg, ( + "unknown scheduler config: ", + scheduler_cfg, + ) + milestones: list[int] = scheduler_cfg["milestones"] + + self.scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, schedulers, milestones + ) + + elif self.optimizer is not None: + ## default to a passthrough LR schedule + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: 1 + ) + + # restore + if weights_path: + print(f"Loading weights from {weights_path}") + self.load_checkpoint(weights_path, optimizer_path) + if self.rank == 0: + for name, param in self.model.named_parameters(): + _p = param.data.float() + if torch.isnan(_p).any() or torch.isinf(_p).any(): + print(f" [NaN debug rank-0] CORRUPTED param after checkpoint load: {name}, has_nan={torch.isnan(_p).any().item()}, has_inf={torch.isinf(_p).any().item()}") + break + else: + print( + "No weights path 
provided. Loaded base HF weights via Checkpointer (default policy init)"
+ )
+
+ def _apply_temperature_scaling(self, logits: torch.Tensor, skip: bool = False) -> torch.Tensor:
+ if skip:
+ return logits
+ if "generation" in self.cfg and self.cfg["generation"] is not None:
+ temp = self.cfg["generation"]["temperature"]
+ if temp > 0:
+ logits.div_(temp)
+ return logits
+
+ def check_model_allow_flash_attn_args(self, model_config) -> bool:
+ # Some models don't support flash_attn_kwargs,
+ # e.g., Nemotron-NAS.
+ if (
+ model_config.architectures[0] == "DeciLMForCausalLM"
+ and model_config.model_type == "nemotron-nas"
+ ):
+ return False
+
+ return True
+
+ def init_cross_tokenizer_loss_fn(self, loss_config, token_aligner_config):
+ """Build CrossTokenizerDistillationLossFn locally from config + shared filesystem."""
+ from nemo_rl.algorithms.x_token import TokenAligner
+ from nemo_rl.algorithms.loss_functions import CrossTokenizerDistillationLossFn
+
+ aligner = TokenAligner(
+ teacher_tokenizer_name=token_aligner_config["teacher_model"],
+ student_tokenizer_name=token_aligner_config["student_model"],
+ max_comb_len=token_aligner_config.get("max_comb_len", 4),
+ projection_matrix_multiplier=token_aligner_config.get("projection_matrix_multiplier", 1.0),
+ )
+ aligner._load_logits_projection_map(
+ file_path=token_aligner_config["projection_matrix_path"],
+ use_sparse_format=token_aligner_config.get("use_sparse_format", True),
+ learnable=token_aligner_config.get("learnable", False),
+ device="cpu",
+ )
+ if token_aligner_config.get("project_teacher_to_student", False):
+ aligner.create_reverse_projection_matrix(device="cpu")
+ self._cached_loss_fn = CrossTokenizerDistillationLossFn(loss_config, aligner)
+
+ def update_cross_tokenizer_data(self, teacher_input_ids, aligned_pairs) -> None:
+ """Update per-step cross-tokenizer data on the cached loss function."""
+ if hasattr(self, '_cached_loss_fn') and self._cached_loss_fn is not None:
+
self._cached_loss_fn.set_cross_tokenizer_data( + teacher_input_ids=teacher_input_ids, + aligned_pairs=aligned_pairs, + ) @wrap_with_nvtx_name("dtensor_policy_worker_v2/train") def train( @@ -332,8 +619,13 @@ def train( eval_mode: bool = False, gbs: Optional[int] = None, mbs: Optional[int] = None, + is_teacher: bool = False, + teacher_logits: Optional[Any] = None, + topk_logits: Optional[int] = None, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" + if loss_fn is None: + loss_fn = getattr(self, '_cached_loss_fn', None) if gbs is None: gbs = self.cfg["train_global_batch_size"] if mbs is None: @@ -347,121 +639,439 @@ def train( ) num_global_batches = int(total_dataset_size.item()) // gbs - # Validate sequence dimension - sequence_dim, _ = check_sequence_dim(data) + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) - if eval_mode: + if is_teacher: ctx: AbstractContextManager[Any] = torch.no_grad() self.model.eval() + elif eval_mode: + ctx = torch.no_grad() + self.model.eval() else: ctx = nullcontext() - # Ensure model is in training mode self.model.train() - # Create loss post-processor - loss_post_processor = LossPostProcessor( - loss_fn=loss_fn, - cfg=self.cfg, - device_mesh=self.device_mesh, - cp_mesh=self.cp_mesh, - tp_mesh=self.tp_mesh, - cp_size=self.cp_size, - dp_size=self.dp_size, - enable_seq_packing=self.enable_seq_packing, - sampling_params=self.sampling_params, - ) - - # Create train context factory - def train_context_fn(processed_inputs): - return get_train_context( - cp_size=self.cp_size, - cp_mesh=self.cp_mesh, - cp_buffers=processed_inputs.cp_buffers, - sequence_dim=sequence_dim, - 
dtype=self.dtype, - autocast_enabled=self.autocast_enabled, - ) - - # Setup cache clearing callback if configured - empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get( - "clear_cache_every_n_steps" - ) - if empty_cache_steps: - warnings.warn( - f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead.", - ) - - def on_microbatch_start(mb_idx): - if empty_cache_steps and mb_idx % empty_cache_steps == 0: - torch.cuda.empty_cache() - with ctx: # Get data from batch and move to device - data = data.to("cuda") + data.to("cuda") losses = [] all_mb_metrics = [] for gb_idx in range(num_global_batches): - # Process global batch and compute normalization factors - gb_result = process_global_batch( - data, - loss_fn, - self.dp_mesh.get_group(), - batch_idx=gb_idx, - batch_size=local_gbs, - ) - batch = gb_result["batch"] - global_valid_seqs = gb_result["global_valid_seqs"] - global_valid_toks = gb_result["global_valid_toks"] - - self.optimizer.zero_grad() - - # Get microbatch iterator based on batching strategy - processed_iterator, iterator_len = get_microbatch_iterator( - batch, - self.cfg, - mbs, - self.dp_mesh, - tokenizer=self.tokenizer, - cp_size=self.cp_size, - ) + global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) - # Use automodel_forward_backward for the training loop - mb_results = automodel_forward_backward( - model=self.model, - data_iterator=processed_iterator, - post_processing_fn=loss_post_processor, - forward_only=eval_mode, - is_reward_model=self._is_reward_model, - allow_flash_attn_args=self.allow_flash_attn_args, - global_valid_seqs=global_valid_seqs, - global_valid_toks=global_valid_toks, - sampling_params=self.sampling_params, - sequence_dim=sequence_dim, - dp_size=self.dp_size, - cp_size=self.cp_size, - num_global_batches=num_global_batches, - train_context_fn=train_context_fn, - num_valid_microbatches=iterator_len, - on_microbatch_start=on_microbatch_start, + assert 
"sample_mask" in global_batch, ( + "sample_mask must be present in the data!" ) + ## get the normalization factor for the loss + local_valid_seqs = torch.sum(global_batch["sample_mask"]) + + if "token_mask" not in global_batch: + local_valid_toks = ( + local_valid_seqs * global_batch["input_ids"].shape[1] + ) + else: + local_valid_toks = torch.sum( + global_batch["token_mask"][:, 1:] + * global_batch["sample_mask"].unsqueeze(-1) + ) + + to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda() + torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group()) + global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1] + + if ( + hasattr(loss_fn, "loss_type") + and loss_fn.loss_type == LossType.TOKEN_LEVEL + ): + assert "token_mask" in global_batch, ( + "token_mask must be present in the data when using token-level loss" + ) - # Extract losses and metrics from results + if not is_teacher: + self.optimizer.zero_grad() mb_losses = [] - for mb_idx, (loss, loss_metrics) in enumerate(mb_results): - # Only process valid (non-dummy) batches for metrics - if mb_idx < iterator_len: - num_valid_samples = loss_metrics["num_valid_samples"] - loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] - loss_metrics["global_valid_seqs"] = global_valid_seqs.item() - loss_metrics["global_valid_toks"] = global_valid_toks.item() + batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + # Calculate number of microbatches to process + # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size + # so its safe to not check for the case where the last data slice is smaller than mbs + dummy_iterator = iter([]) + if self.cfg.get("dynamic_batching", {}).get("enabled", False): + mb_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = batch.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + 
iterator_len, max_seqlen = (
+ batch.get_microbatch_iterator_for_packable_sequences_len()
+ )
+ max_batch_ct = torch.tensor([iterator_len], device="cuda")
+ torch.distributed.all_reduce(
+ max_batch_ct, op=torch.distributed.ReduceOp.MAX
+ )
+
+ # Sequence packing can end up with unevenly distributed batch counts across DP ranks.
+ # We add dummy batches to the end of the iterator to make the batch counts equal.
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len)
+ dummy_iterator = (
+ batch.make_microbatch_iterator_for_packable_sequences()
+ )
+ dummy_iterator = itertools.islice(
+ itertools.cycle(dummy_iterator), dummy_batch_ct
+ )
+ else:
+ mb_iterator = batch.make_microbatch_iterator(mbs)
+ iterator_len = batch.size // mbs
+ empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get(
+ "clear_cache_every_n_steps"
+ )
+ if empty_cache_steps:
+ warnings.warn(
+ f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead."
+ )
+
+ _teacher_mb_handles = []
+ _teacher_is_topk = False
+ if not is_teacher and teacher_logits is not None:
+ rank = torch.distributed.get_rank()
+ worker_result = teacher_logits[rank]
+ _teacher_mb_handles = worker_result['microbatch_handles']
+ _teacher_is_topk = worker_result.get('is_topk', False)
+
+ for mb_idx, mb in enumerate(
+ itertools.chain(mb_iterator, dummy_iterator)
+ ):
+ # Conditionally empty cache when sensitive to fragmentation
+ if empty_cache_steps and mb_idx % empty_cache_steps == 0:
+ torch.cuda.empty_cache()
+
+ with torch.autocast(device_type="cuda", dtype=self.dtype):
+ if self.enable_seq_packing:
+ input_ids = mb.get("input_ids").cuda()
+ input_ids, position_ids, _ = pack_sequences(
+ input_ids=input_ids,
+ input_lengths=mb["input_lengths"],
+ packed_sequence_size=[
+ len(mb["input_lengths"])
+ ], # flash attention 2 expects flattened input
+ padding_value=self.tokenizer.eos_token_id,
+ return_attention_mask=False,
+ min_seq_len=self.cfg["sequence_packing"][
+ "train_mb_tokens"
+ ], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled.
+ )
+ seq_len = input_ids.shape[1]
+ attention_mask = None
+ flash_attn_kwargs = get_flash_attention_kwargs(
+ input_lengths=mb["input_lengths"],
+ )
+
+ else:
+ input_ids = mb.get("input_ids").cuda()
+ batch_size, seq_len = input_ids.shape
+
+ attention_mask = torch.ones(
+ (batch_size, seq_len),
+ dtype=torch.bool,
+ device=input_ids.device,
+ )
+ position_ids = torch.arange(
+ seq_len, device=input_ids.device
+ ).repeat(batch_size, 1)
+ flash_attn_kwargs = {}
+
+ # add vlm kwargs to model call
+ vlm_kwargs = mb.get_multimodal_dict(
+ as_tensors=True, device=input_ids.device
+ )
+ if len(vlm_kwargs) > 0:
+ position_ids = None
+ assert not self.cfg["dtensor_cfg"]["sequence_parallel"], (
+ "Sequence parallel is not supported with multimodal since there's an issue when you do not pass position_ids. 
See https://github.com/NVIDIA-NeMo/Automodel/issues/652" + ) + + context_parallel_ctx = None + if self.cp_size > 1: + assert len(vlm_kwargs) == 0, ( + f"multimodal kwargs={vlm_kwargs} are not supported for context parallel" + ) + seq_index = torch.arange( + seq_len, device=input_ids.device + ).repeat(1, 1) + cp_buffers = ( + [input_ids, position_ids, seq_index] + if self.cp_size > 1 + else [] + ) + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + + if self._is_reward_model: + # `flash_attn_kwarg` is not supported for `LlamaForSequenceClassification`. + # Note that it should be empty anyway since sequence packing + # is not supported for reward models. 
+ assert not flash_attn_kwargs + del model_args["flash_attn_kwargs"] + # remove flash_attn_kwargs if there are multimodal kwargs + if len(vlm_kwargs) > 0: + del model_args["flash_attn_kwargs"] + + if ( + not self.allow_flash_attn_args + and "flash_attn_kwargs" in model_args + ): + del model_args["flash_attn_kwargs"] + + outputs = self.model(**model_args) + + # Get logprobs + if isinstance(outputs, (torch.Tensor, DTensor)): + # custom models (e.g., those coming from AutoModel) can output logits directly + logits = outputs + elif not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + # Temperature scaling is only for inference/generation, not training + logits = self._apply_temperature_scaling(logits, skip=True) + + if self.cp_size > 1: + seq_index_dtensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + mb["seq_index"] = seq_index_dtensor + + for tensor_name in mb: + current_tensor = mb[tensor_name] + for buffer in cp_buffers: + if current_tensor is buffer: + assert type(current_tensor) == torch.Tensor, ( + f"tensor {tensor_name} is not a tensor" + ) + mb[tensor_name] = DTensor.from_local( + current_tensor, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + break + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + if is_teacher: + with torch.no_grad(): + if 
isinstance(logits, DTensor): + mb_logits_local = logits.to_local() + else: + mb_logits_local = logits + del logits + + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(mb_logits_local.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + mb_logits_local = mb_logits_local.to(torch.float32) + mb_log_prob = _compute_distributed_log_softmax(mb_logits_local, group=tp_group) + del mb_logits_local + + if isinstance(mb_log_prob, DTensor): + mb_log_prob = mb_log_prob.to_local() + + if self.cfg.get('is_mdlm', False): + shared_seq_len = int(mb_log_prob.shape[1] / 2) + mb_log_prob = mb_log_prob[:, shared_seq_len:, :] + + if topk_logits is not None: + mb_topk_vals, mb_topk_idx = distributed_vocab_topk( + mb_log_prob, + k=topk_logits, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del mb_log_prob + + B_mb, S_mb, K_mb = mb_topk_vals.shape + buf_idx = len(_teacher_mb_handles) + self._ensure_teacher_mb_topk_buffer( + buf_idx, B_mb, K_mb, + mb_topk_vals.dtype, mb_topk_idx.dtype, mb_topk_vals.device, + ) + self._teacher_mb_vals_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_vals) + self._teacher_mb_idx_buffers[buf_idx][:B_mb, :S_mb, :K_mb].copy_(mb_topk_idx) + del mb_topk_vals, mb_topk_idx + + rank = torch.distributed.get_rank() + _teacher_mb_handles.append({ + rank: self._teacher_mb_vals_ipcs[buf_idx], + 'actual_shape': (B_mb, S_mb, K_mb), + 'topk_indices_ipc': self._teacher_mb_idx_ipcs[buf_idx], + }) + else: + B_mb, S_mb, V_mb = mb_log_prob.shape + buf_idx = len(_teacher_mb_handles) + self._ensure_teacher_mb_logits_buffer( + buf_idx, B_mb, V_mb, + mb_log_prob.dtype, mb_log_prob.device, + ) + self._teacher_mb_logits_buffers[buf_idx][:B_mb, :S_mb, :V_mb].copy_(mb_log_prob) + del mb_log_prob + + rank = torch.distributed.get_rank() + _teacher_mb_handles.append({ + rank: self._teacher_mb_logits_ipcs[buf_idx], + 'actual_shape': (B_mb, 
S_mb, V_mb), + }) + else: + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + # ── NaN debug: inspect logits on first microbatch of first 2 steps ── + if mb_idx == 0 and gb_idx == 0 and self.rank == 0 and len(losses) < 2: + _local_logits = logits.to_local() if isinstance(logits, DTensor) else logits + _lf = _local_logits.float() + print( + f" [NaN debug rank-0] logits shape={_local_logits.shape}, " + f"dtype={_local_logits.dtype}, " + f"min={_lf.min().item():.4f}, max={_lf.max().item():.4f}, " + f"has_nan={torch.isnan(_lf).any().item()}, " + f"has_inf={torch.isinf(_lf).any().item()}, " + f"global_valid_toks={global_valid_toks.item():.0f}, " + f"global_valid_seqs={global_valid_seqs.item():.0f}", + flush=True, + ) + del _local_logits, _lf + + if _teacher_mb_handles and mb_idx < len(_teacher_mb_handles) and not self.enable_seq_packing: + rank = torch.distributed.get_rank() + current_device_id = torch.cuda.current_device() + handle = _teacher_mb_handles[mb_idx] + aB, aS, aK = handle['actual_shape'] + + teacher_logits_tensor = rebuild_cuda_tensor_from_ipc( + handle[rank], current_device_id + ).detach() + teacher_logits_tensor = teacher_logits_tensor[:aB, :aS, :aK].clone() + + teacher_topk_indices_tensor = None + if _teacher_is_topk and 'topk_indices_ipc' in handle: + teacher_topk_indices_tensor = rebuild_cuda_tensor_from_ipc( + handle['topk_indices_ipc'], current_device_id + ).detach() + teacher_topk_indices_tensor = teacher_topk_indices_tensor[:aB, :aS, :aK].clone() + + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + teacher_logits=teacher_logits_tensor, + teacher_topk_indices_ipc=teacher_topk_indices_tensor, + ) + del teacher_logits_tensor + if teacher_topk_indices_tensor is not None: + del teacher_topk_indices_tensor + else: + loss, loss_metrics = 
loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) + del logits + + # skip the update for dummy batches + if mb_idx < iterator_len: + ## scale by the number of global batches so we get the correct + ## value when summing metrics across all microbatches + for k in loss_metrics.keys(): + loss_metrics[k] /= num_global_batches + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + else: + loss *= 0 + + # Backward pass + if not eval_mode: + loss *= self.dp_size * self.cp_size + loss.backward() + + if not is_teacher: if num_valid_samples > 0: mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) + if is_teacher: + self.teacher_logits = { + 'microbatch_handles': _teacher_mb_handles, + 'is_topk': topk_logits is not None, + } + return self.teacher_logits + grad_norm: Optional[float | torch.Tensor] = None - if not eval_mode: + if not is_teacher and not eval_mode: grad_norm = scale_grads_and_clip_grad_norm( self.max_grad_norm, [self.model], @@ -485,28 +1095,39 @@ def on_microbatch_start(mb_idx): # Update parameters self.optimizer.step() - losses.append(torch.tensor(mb_losses).sum().item()) + if not is_teacher: + losses.append(torch.tensor(mb_losses).sum().item()) # release gradient memory before rollouts self.optimizer.zero_grad() - # increment scheduler after all batches in rollout are processed if not eval_mode: self.scheduler.step() - # dynamic batch and sequence dims causes alot of fragmentation, so clear - # the memory allocator before moving on torch.cuda.empty_cache() - # Aggregate training statistics across microbatches and ranks - metrics = aggregate_training_statistics( - losses=losses, - all_mb_metrics=all_mb_metrics, - grad_norm=grad_norm, - dp_group=self.dp_mesh.get_group(), - dtype=self.dtype, - ) + # Compute global loss across all ranks + with 
torch.no_grad(): + global_loss = torch.tensor(losses, device="cuda") + torch.distributed.all_reduce( + global_loss, group=self.dp_mesh.get_group() + ) + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "grad_norm": grad_norm, + "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, + "all_mb_metrics": dict(mb_metrics), + } return metrics + # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094) @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_logprobs") def get_logprobs( self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None @@ -528,62 +1149,274 @@ def get_logprobs( if micro_batch_size is not None else self.cfg["logprob_batch_size"] ) - - # Validate sequence dimension - sequence_dim, seq_dim_size = check_sequence_dim(data) + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) + + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) all_log_probs = [] self.model.eval() - # Create logprobs post-processor - logprobs_post_processor = LogprobsPostProcessor( - cfg=self.cfg, - device_mesh=self.device_mesh, - cp_mesh=self.cp_mesh, - tp_mesh=self.tp_mesh, - cp_size=self.cp_size, - enable_seq_packing=self.enable_seq_packing, - sampling_params=self.sampling_params, - ) - with torch.no_grad(): data.to("cuda") - # Get microbatch iterator based on batching strategy - processed_iterator, iterator_len = get_microbatch_iterator( - data, - 
self.cfg, - logprob_batch_size, - self.dp_mesh, - tokenizer=self.tokenizer, - cp_size=self.cp_size, - ) + dummy_iterator = iter([]) + if self.cfg.get("dynamic_batching", {}).get("enabled", False): + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) - for batch_idx, processed_mb in enumerate(processed_iterator): - processed_inputs = processed_mb.processed_inputs + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(logprob_batch_size) + iterator_len = data.size // logprob_batch_size + + step = 0 + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + vlm_kwargs = lp_batch.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) - with get_train_context( - cp_size=self.cp_size, - cp_mesh=self.cp_mesh, - cp_buffers=processed_inputs.cp_buffers, - sequence_dim=sequence_dim, - dtype=self.dtype, - autocast_enabled=self.autocast_enabled, - ): - # Use forward_with_post_processing_fn for forward pass and post-processing - token_logprobs, _metrics, _ = forward_with_post_processing_fn( - model=self.model, - 
post_processing_fn=logprobs_post_processor,
- processed_mb=processed_mb,
- is_reward_model=False,
- allow_flash_attn_args=self.allow_flash_attn_args,
- sampling_params=self.sampling_params,
- sequence_dim=sequence_dim,
+ batch_size, seq_len = input_ids.shape
+ if self.enable_seq_packing:
+ assert len(vlm_kwargs) == 0, (
+ "multimodal kwargs are not supported for sequence packing"
+ )
+ input_ids, position_ids, _ = pack_sequences(
+ input_ids=input_ids,
+ input_lengths=input_lengths,
+ packed_sequence_size=[
+ batch_size
+ ], # flash attention 2 expects flattened input
+ padding_value=self.tokenizer.eos_token_id,
+ return_attention_mask=False,
+ )
+ seq_len = input_ids.shape[1]
+ attention_mask = None
+ flash_attn_kwargs = get_flash_attention_kwargs(
+ input_lengths=input_lengths,
+ )
+ else:
+ # Create post_attention_mask for right-padded data, used to mask tokens after the forward pass.
+ post_attention_mask = torch.zeros(
+ (batch_size, seq_len), dtype=torch.bool, device=input_ids.device
)
+ for i, length in enumerate(input_lengths):
+ # For right-padded sequences, set 1s at the beginning of the sequence
+ post_attention_mask[i, :length] = 1
+
+ # explicitly create position ids for the input, otherwise the sharding
+ # for DTensor will be incorrect
+ position_ids = torch.arange(
+ seq_len, device=input_ids.device
+ ).repeat(batch_size, 1)
+ flash_attn_kwargs = {}
+
+ # DTensor requires the causal attention kernel to hit,
+ # yet our attention mask above is not always all 1s.
+ # This is fine because we mask with the actual attention mask
+ # later, but the input mask has to be all 1s.
+ attention_mask = torch.ones(
+ (batch_size, seq_len),
+ dtype=torch.bool,
+ device=input_ids.device,
+ )
+
+ # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
+ if len(vlm_kwargs) > 0:
+ position_ids = None
+
+ context_parallel_ctx = None
+ if self.cp_size > 1:
+ assert len(vlm_kwargs) == 0, (
+ "multimodal kwargs are not supported for context parallel"
+
) + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + if len(vlm_kwargs) > 0: + del model_args["flash_attn_kwargs"] + + if ( + not self.allow_flash_attn_args + and "flash_attn_kwargs" in model_args + ): + del model_args["flash_attn_kwargs"] + + outputs = self.model(**model_args) + + logits = outputs.logits if hasattr(outputs, "logits") else outputs + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + seq_index_tensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + input_ids_dtensor = DTensor.from_local( + input_ids, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + token_logprobs = get_logprobs_from_vocab_parallel_logits( + 
logits, + input_ids_dtensor, + seq_index_tensor, + chunk_size=logprob_chunk_size, + ) + + assert token_logprobs.shape[1] == seq_len - 1 + else: + if isinstance(logits, DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits, + input_ids, + chunk_size=logprob_chunk_size, + ) + else: + if logprob_chunk_size is not None: + logits_seq_len = int(logits.shape[1]) + num_chunks = ( + logits_seq_len + logprob_chunk_size - 1 + ) // logprob_chunk_size + chunked_log_probs = [] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * logprob_chunk_size + chunk_end = min( + logits_seq_len, + (chunk_idx + 1) * logprob_chunk_size, + ) + chunk_logits = logits[ + :, chunk_start:chunk_end, : + ].to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + chunk_logits, dim=-1 + ) + chunked_log_probs.append(log_probs) + log_probs = torch.cat(chunked_log_probs, dim=1) + del chunked_log_probs + else: + logits = logits.to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + logits, dim=-1 + ) + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + del log_probs + + del outputs, logits + + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) # skip keeping the logprobs for the dummy batches if batch_idx >= iterator_len: continue + if not self.enable_seq_packing: + # Apply mask to zero out padding tokens logprobs + token_logprobs = 
token_logprobs * post_attention_mask + else: + # For packed sequences, unpack logprobs + unpacked_logprobs = torch.zeros( + (batch_size, seq_dim_size), + dtype=token_logprobs.dtype, + device=token_logprobs.device, + ) + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + for i in range(batch_size): + start = cu_seqlens[i].item() + 1 + end = cu_seqlens[i + 1].item() + seq_len_actual = input_lengths[i].item() + unpacked_logprobs[i, 1:seq_len_actual] = token_logprobs[ + 0, start:end + ] + token_logprobs = unpacked_logprobs + all_log_probs.append(token_logprobs) # Concatenate all batches @@ -601,58 +1434,121 @@ def get_logprobs( return return_data + # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094) @wrap_with_nvtx_name("dtensor_policy_worker_v2/score") def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: global_batch_size = min(self.cfg["batch_size"], data.size) - # Validate sequence dimension - sequence_dim, _ = check_sequence_dim(data) - + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) self.model.eval() print("Begin to batch datas") - - # Create score post-processor - score_post_processor = ScorePostProcessor(cfg=self.cfg) - with torch.no_grad(): data.to("cuda") - # Get microbatch iterator based on batching strategy - processed_iterator, iterator_len = get_microbatch_iterator( - data, - self.cfg, - global_batch_size, - self.dp_mesh, - tokenizer=self.tokenizer, - cp_size=self.cp_size, - ) - + dummy_iterator = iter([]) + if self.cfg.get("dynamic_batching", {}).get("enabled", False): + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + 
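The unpacking loop above scatters a flat packed row back into a right-padded `[batch, seq]` layout using the cumulative sequence lengths (`cu_seqlens`), skipping each sequence's first position to match the dummy leading logprob. A pure-Python sketch of that indexing, with lists standing in for tensors (names are illustrative):

```python
def unpack_packed_logprobs(packed, cu_seqlens, input_lengths, padded_len):
    """Scatter a flat packed row of per-token logprobs into a right-padded
    [batch, padded_len] layout. Position 0 of each sequence stays 0.0,
    matching the prepended dummy first-token logprob."""
    out = []
    for i, length in enumerate(input_lengths):
        row = [0.0] * padded_len
        start = cu_seqlens[i] + 1  # skip the first token of each sequence
        end = cu_seqlens[i + 1]
        row[1:length] = packed[start:end]
        out.append(row)
    return out
```

The slice lengths line up because `cu_seqlens[i + 1] - cu_seqlens[i]` equals `input_lengths[i]`, so `end - start` is exactly `length - 1`.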
elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(global_batch_size) + iterator_len = data.size // global_batch_size + step = 0 all_rm_scores = [] - for batch_idx, processed_mb in enumerate(processed_iterator): - processed_inputs = processed_mb.processed_inputs - - with get_train_context( - cp_size=self.cp_size, - cp_mesh=self.cp_mesh, - cp_buffers=processed_inputs.cp_buffers, - sequence_dim=sequence_dim, - dtype=self.dtype, - autocast_enabled=self.autocast_enabled, - ): - # Use forward_with_post_processing_fn for forward pass and post-processing - rm_scores, _metrics, _ = forward_with_post_processing_fn( - model=self.model, - post_processing_fn=score_post_processor, - processed_mb=processed_mb, - is_reward_model=True, - allow_flash_attn_args=False, - sampling_params=self.sampling_params, - sequence_dim=sequence_dim, + for batch_idx, generate_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 + input_ids = generate_batch.get("input_ids").cuda() + input_lengths = generate_batch.get("input_lengths") + batch_size, seq_len = input_ids.shape + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, ) + seq_len = input_ids.shape[1] + 
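With sequence packing, different data-parallel ranks can end up with different microbatch counts, but every rank must run the same number of forward passes because they all participate in collectives. The code above takes an `all_reduce(MAX)` of the local count and appends recycled "dummy" batches to make up the difference. A sketch of that balancing in pure Python (`max(counts_per_rank)` stands in for the all-reduce; names are illustrative):

```python
import itertools


def pad_with_dummy_batches(batches, counts_per_rank, my_rank):
    """Extend this rank's microbatch list with recycled dummy batches so every
    rank iterates max(counts_per_rank) times. Dummy outputs are discarded by
    the `batch_idx >= iterator_len` check in the training loop."""
    max_ct = max(counts_per_rank)          # stand-in for all_reduce(MAX)
    dummy_ct = max_ct - counts_per_rank[my_rank]
    dummies = itertools.islice(itertools.cycle(batches), dummy_ct)
    return list(itertools.chain(batches, dummies))
```

Cycling the rank's own data (rather than fabricating zeros) keeps tensor shapes realistic, at the cost of a few wasted forward passes on the underloaded ranks.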
attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Create attention mask for right-padded data + post_attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.bool, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + post_attention_mask[i, :length] = 1 + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + context_parallel_ctx = None + if self.cp_size > 1: + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ) + outputs = self.model(**model_args) + + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + if isinstance(logits, DTensor): + logits = logits.to(torch.float32) + else: + # Cast the temperature-scaled logits (re-reading outputs.logits + # here would silently discard the scaling) + logits = logits.to(torch.float32) - # skip keeping the scores for the dummy batches - if batch_idx >= iterator_len: - continue - + rm_scores = to_local_if_dtensor(logits) + rm_scores = rm_scores.squeeze(-1) all_rm_scores.append(rm_scores) all_rm_scores = torch.cat(all_rm_scores, dim=0) @@ -687,76 +1583,282 @@ def get_topk_logits( else 
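For right-padded inputs, the mask marks the first `input_lengths[i]` positions of each row valid, while position ids simply count `0..seq_len-1` per row. A small pure-Python sketch of that construction (lists stand in for the boolean/long tensors built above):

```python
def right_padded_mask_and_positions(input_lengths, seq_len):
    """Build a right-padding attention mask (1s over the first `length`
    positions of each row) and per-row 0..seq_len-1 position ids."""
    mask = [
        [int(j < length) for j in range(seq_len)] for length in input_lengths
    ]
    position_ids = [list(range(seq_len)) for _ in input_lengths]
    return mask, position_ids
```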
self.cfg["logprob_batch_size"] ) - # Validate sequence dimension - sequence_dim, seq_dim_size = check_sequence_dim(data) + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] out_topk_vals = [] out_topk_idx = [] self.model.eval() - # Create top-k post-processor - topk_post_processor = TopkLogitsPostProcessor( - cfg=self.cfg, - device_mesh=self.device_mesh, - cp_mesh=self.cp_mesh, - tp_mesh=self.tp_mesh, - cp_size=self.cp_size, - k=k, - enable_seq_packing=self.enable_seq_packing, - ) - with torch.no_grad(): data.to("cuda") - # Get microbatch iterator based on batching strategy - processed_iterator, iterator_len = get_microbatch_iterator( - data, - self.cfg, - topk_batch_size, - self.dp_mesh, - tokenizer=self.tokenizer, - cp_size=self.cp_size, - ) + dummy_iterator = iter([]) + if self.cfg.get("dynamic_batching", {}).get("enabled", False): + # dynamic batching support (no CP/packed) + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) - for batch_idx, processed_mb in enumerate(processed_iterator): - processed_inputs = processed_mb.processed_inputs + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(topk_batch_size) + iterator_len = data.size // topk_batch_size + + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + + batch_size, seq_len = input_ids.shape + # Store original shapes for unpacking later + original_batch_size = batch_size + original_seq_len = seq_len + + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Build attention mask (right-padded inputs) + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 - with get_train_context( - cp_size=self.cp_size, - cp_mesh=self.cp_mesh, - cp_buffers=processed_inputs.cp_buffers, - sequence_dim=sequence_dim, - dtype=self.dtype, - autocast_enabled=self.autocast_enabled, - ): - # Use forward_with_post_processing_fn for forward pass and post-processing - (vals, idx), _metrics, _ = forward_with_post_processing_fn( - model=self.model, - post_processing_fn=topk_post_processor, - processed_mb=processed_mb, - is_reward_model=False, - allow_flash_attn_args=self.allow_flash_attn_args, - sampling_params=self.sampling_params, - sequence_dim=sequence_dim, + position_ids = torch.arange( + seq_len, 
device=input_ids.device + ).repeat(batch_size, 1) + + flash_attn_kwargs = {} + + with torch.autocast(device_type="cuda", dtype=self.dtype): + attention_mask_input_all_ones = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device ) - # skip keeping the topk values for the dummy batches - if batch_idx >= iterator_len: - continue + context_parallel_ctx = None + if self.cp_size > 1: + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + ) + + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + # deal with TP first + local_logits = logits.to_local() # [B, S_cp, V_tp] + + tp_group = 
self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + vals, idx = distributed_vocab_topk( + local_logits, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + # [B, S_cp, k] + + cp_group = self.cp_mesh.get_group() + + vals = allgather_cp_sharded_tensor( + vals, cp_group, seq_dim=sequence_dim + ) + idx = allgather_cp_sharded_tensor( + idx, cp_group, seq_dim=sequence_dim + ) + # [B, S, k] + else: + # Compute top-k over full sequence length (do not drop last position) + if isinstance(logits, DTensor): + local_logits = logits.to_local() # [B, S, V_local] + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + local_logits = local_logits.to(torch.float32) + local_log_probs = _compute_distributed_log_softmax(local_logits, group=tp_group) + del logits, local_logits + + if isinstance(local_log_probs, DTensor): + local_log_probs = local_log_probs.to_local() + + if self.cfg.get('is_mdlm', False): + shared_sequence_length = int(local_log_probs.shape[1] // 2) + local_log_probs = local_log_probs[:, shared_sequence_length:, :] + + vals, idx = distributed_vocab_topk( + local_log_probs, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del local_log_probs + else: + full_logits = logits.to(torch.float32) + vals, idx = torch.topk(full_logits, k=k, dim=-1) + + # Handle sequence packing unpacking + if self.enable_seq_packing: + # Unpack top-k results from packed format back to original batch format + # vals: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + # idx: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + + # Create tensors 
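`distributed_vocab_topk` works on tensor-parallel vocab shards: each rank takes a *local* top-k over its contiguous vocab slice (shifting indices by `vocab_start_index` so they become global vocab ids), and the `k * tp` candidates are then reduced to a global top-k. A single-process sketch of that two-stage reduction, with lists standing in for the sharded tensors and a sort standing in for the collective (names are illustrative, not the real NeMo RL signature):

```python
def distributed_vocab_topk_sketch(shards, k):
    """shards: one (vocab_start, scores) pair per TP rank, each holding a
    contiguous slice of the vocabulary. Stage 1: local top-k per shard with
    globally-offset indices. Stage 2: merge the k * tp candidates into a
    global top-k (the sort stands in for the all-gather + reduction)."""
    candidates = []
    for vocab_start, scores in shards:
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        candidates += [(scores[i], vocab_start + i) for i in order[:k]]
    candidates.sort(key=lambda c: c[0], reverse=True)
    vals = [v for v, _ in candidates[:k]]
    idx = [i for _, i in candidates[:k]]
    return vals, idx
```

The key property this relies on: the global top-k is always a subset of the per-shard top-k sets, so only `k` candidates per rank ever need to be exchanged instead of the full vocab shard.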
to store unpacked results + unpacked_vals = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=vals.dtype, + device=vals.device, + ) + unpacked_idx = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=idx.dtype, + device=idx.device, + ) + + # Get cumulative sequence lengths for unpacking + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + + for i in range(original_batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + seq_len_actual = input_lengths[i].item() + + # Extract the corresponding portion from packed results + # Note: vals and idx are [1, packed_seq_len, k] due to packing + unpacked_vals[i, :seq_len_actual, :] = vals[0, start:end, :] + unpacked_idx[i, :seq_len_actual, :] = idx[0, start:end, :] + + # Replace with unpacked results + vals = unpacked_vals + idx = unpacked_idx + + # Update batch_size and seq_len for consistency + batch_size = original_batch_size + seq_len = original_seq_len - # Keep only real sequence tokens (no trimming here; padded positions can be masked downstream) # Shapes remain [B, S, k]. - out_topk_vals.append(vals.cpu()) - out_topk_idx.append(idx.cpu()) + B_mb, S_mb, K_mb = vals.shape + target_dtype = vals.dtype + target_device = vals.device + + # Pre-allocate two IPC buffers (values + indices) exactly once. 
+ if not hasattr(self, '_teacher_topk_vals_buffer') or self._teacher_topk_vals_buffer is None: + max_S = self.cfg.get("max_total_sequence_length", S_mb) + vals_buf_shape = (B_mb, max_S, K_mb) + self._teacher_topk_vals_buffer = torch.empty( + vals_buf_shape, dtype=target_dtype, device=target_device + ) + self._teacher_topk_vals_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_vals_buffer) + } + idx_buf_shape = (B_mb, max_S, K_mb) + self._teacher_topk_idx_buffer = torch.empty( + idx_buf_shape, dtype=idx.dtype, device=target_device + ) + self._teacher_topk_idx_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_idx_buffer) + } + print(f" rank {torch.distributed.get_rank()} Allocated topk IPC buffers: " + f"vals={vals_buf_shape} ({self._teacher_topk_vals_buffer.numel() * self._teacher_topk_vals_buffer.element_size() / 1e9:.4f} GB), " + f"idx={idx_buf_shape} ({self._teacher_topk_idx_buffer.numel() * self._teacher_topk_idx_buffer.element_size() / 1e9:.4f} GB) " + f"(actual data: [{B_mb}, {S_mb}, {K_mb}])") + + # Copy actual data into the top-left slice of the buffers + self._teacher_topk_vals_buffer[:B_mb, :S_mb, :K_mb].copy_(vals) + self._teacher_topk_idx_buffer[:B_mb, :S_mb, :K_mb].copy_(idx) + del vals, idx + + out_topk_vals.append(self._teacher_topk_vals_buffer[:B_mb, :S_mb, :K_mb].cpu()) + out_topk_idx.append(self._teacher_topk_idx_buffer[:B_mb, :S_mb, :K_mb].cpu()) ret = BatchedDataDict[Any]() - # Pad each micro-batch result on sequence dim to common length (S), similar to get_logprobs all_topk_vals_padded = [] all_topk_idx_padded = [] target_seq_len = seq_dim_size for vals, idx in zip(out_topk_vals, out_topk_idx): pad_needed = target_seq_len - vals.shape[1] if pad_needed > 0: - # pad along sequence dimension (second dim): (last_dim_pad_left, last_dim_pad_right, seq_pad_left, seq_pad_right, batch_pad_left, batch_pad_right) vals = torch.nn.functional.pad( vals, (0, 0, 0, pad_needed, 0, 0), 
mode="constant", value=0.0 ) @@ -778,52 +1880,61 @@ def get_topk_logits( ).cpu() return ret + def _ensure_teacher_mb_topk_buffer(self, buf_idx, B, K, vals_dtype, idx_dtype, device): + """Lazily grow the per-microbatch IPC buffer pool for top-k teacher logits.""" + if not hasattr(self, '_teacher_mb_vals_buffers'): + self._teacher_mb_vals_buffers = [] + self._teacher_mb_vals_ipcs = [] + self._teacher_mb_idx_buffers = [] + self._teacher_mb_idx_ipcs = [] + max_S = self.cfg.get("max_total_sequence_length", 1) + while len(self._teacher_mb_vals_buffers) <= buf_idx: + vals_buf = torch.empty((B, max_S, K), dtype=vals_dtype, device=device) + idx_buf = torch.empty((B, max_S, K), dtype=idx_dtype, device=device) + self._teacher_mb_vals_buffers.append(vals_buf) + self._teacher_mb_vals_ipcs.append(get_handle_from_tensor(vals_buf)) + self._teacher_mb_idx_buffers.append(idx_buf) + self._teacher_mb_idx_ipcs.append(get_handle_from_tensor(idx_buf)) + + def _ensure_teacher_mb_logits_buffer(self, buf_idx, B, V, dtype, device): + """Lazily grow the per-microbatch IPC buffer pool for full-vocab teacher logits.""" + if not hasattr(self, '_teacher_mb_logits_buffers'): + self._teacher_mb_logits_buffers = [] + self._teacher_mb_logits_ipcs = [] + max_S = self.cfg.get("max_total_sequence_length", 1) + while len(self._teacher_mb_logits_buffers) <= buf_idx: + buf = torch.empty((B, max_S, V), dtype=dtype, device=device) + self._teacher_mb_logits_buffers.append(buf) + self._teacher_mb_logits_ipcs.append(get_handle_from_tensor(buf)) + @contextmanager def use_reference_model(self) -> Generator[None, None, None]: """Context manager that temporarily swaps the reference model and active model. - On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references. - Also disables top-k/top-p filtering since the reference policy's distribution - is different from the current policy, making filtered logprobs incompatible. 
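The `_ensure_teacher_mb_*_buffer` helpers below implement a grow-on-demand pool: buffer `buf_idx` is allocated the first time it is requested, and every earlier slot is filled too so indices stay dense, with each allocation registered once for IPC. A pure-Python sketch of that pattern, where `make_buffer` stands in for `torch.empty` plus `get_handle_from_tensor` registration (names are illustrative):

```python
class LazyBufferPool:
    """Grow-on-demand pool of per-microbatch buffers. ensure(i, ...) allocates
    slots up to and including i exactly once, then returns slot i, so repeated
    calls with the same index reuse the same preallocated buffer."""

    def __init__(self, make_buffer):
        self.buffers = []
        self.make_buffer = make_buffer

    def ensure(self, buf_idx, *shape):
        while len(self.buffers) <= buf_idx:
            self.buffers.append(self.make_buffer(*shape))
        return self.buffers[buf_idx]
```

Allocating once at the maximum sequence length and copying each microbatch into the top-left slice (as the code above does) keeps IPC handles stable across steps, at the cost of holding worst-case-sized buffers.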
- On exit: Restores original references and re-flips cuda/cpu, restores sampling_params. + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references + On exit: Restores original references and re-flips cuda/cpu """ with torch.no_grad(): - # Save train model state_dict - curr_state_dict = get_cpu_state_dict( - self.model.state_dict().items(), pin_memory=True - ) - - # Swap reference model state_dict to self.model - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(self.reference_model_state_dict[k]) - - # Temporarily disable top-k/top-p filtering for reference policy logprobs. - # The reference policy has different weights, so its top-k/top-p set is - # inherently different from the current policy. Using filtered logprobs - # would cause -inf mismatches that cannot be resolved by masking. - # Note: We keep temperature scaling since it was applied to prev_logprobs. - saved_sampling_params = self.sampling_params - if saved_sampling_params is not None: - self.sampling_params = TrainingSamplingParams( - top_k=None, - top_p=1.0, - temperature=saved_sampling_params.temperature, + try: + # Save train model state_dict + curr_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True ) - else: - self.sampling_params = None - # - self.model is the original reference_model, now on CUDA - # - curr_state_dict is the train model, now on CPU - yield + # Swap reference model state_dict to self.model + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(self.reference_model_state_dict[k]) - # Restore sampling_params - self.sampling_params = saved_sampling_params + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU + yield - # Restore train model state_dict - for k, v in self.model.state_dict().items(): - val = to_local_if_dtensor(v) - val.copy_(curr_state_dict[k]) + finally: + # Restore train model 
state_dict + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(curr_state_dict[k]) def _add_noise_to_weights(self) -> None: """Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only.""" @@ -850,17 +1961,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: """Prepare state dict metadata for weight refitting and IPC streaming.""" state_dict_info = {} for name, tensor in self.model.state_dict().items(): - if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"): - continue - full_tensor = ( - tensor.full_tensor() if isinstance(tensor, DTensor) else tensor - ) # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective - adapted_fqn_tensors = _maybe_adapt_tensor_to_hf( - self.model, name, full_tensor - ) - for adapted_fqn, adapted_tensor in adapted_fqn_tensors: - state_dict_info[adapted_fqn] = (adapted_tensor.shape, self.dtype) + state_dict_info[name] = (tensor.shape, self.dtype) return state_dict_info @@ -898,41 +2000,9 @@ def stream_weights_via_ipc_zmq( from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl - # Use the shared implementation - stream_weights_via_ipc_zmq_impl( - params_generator=dtensor_params_generator(self.model, self.dtype), - buffer_size_bytes=buffer_size_bytes, - zmq_socket=self.zmq_socket, - rank=self.rank, - worker_name=str(self), - ) - - @torch.no_grad() - @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_http") - def stream_weights_via_http( - self, - sglang_url_to_gpu_uuids: dict[str, list[str]], - ) -> None: - """Stream model weights to SGLang servers via HTTP API. 
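The rewritten `use_reference_model()` above wraps the weight swap in `try`/`finally` so the training weights are restored even if the body raises. The same discipline in miniature, with plain dicts standing in for (D)Tensor state dicts (a sketch, not the worker's actual API):

```python
from contextlib import contextmanager


@contextmanager
def swapped_weights(live_state, reference_state):
    """Swap reference weights into the live model for the duration of the
    block; the finally clause guarantees the originals come back even if
    the body raises."""
    saved = {k: v for k, v in live_state.items()}
    try:
        for k in live_state:
            live_state[k] = reference_state[k]
        yield
    finally:
        for k in live_state:
            live_state[k] = saved[k]
```

Without the `finally`, an exception during reference-model inference would leave the worker silently training on reference weights, which is exactly the failure mode the refactor closes.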
- - Args: - sglang_url_to_gpu_uuids: Dict mapping SGLang server URL to list of GPU UUIDs it uses - """ - # Manually move model to cuda for cpu offload case - if self.cpu_offload: - self.model = self.move_to_cuda(self.model) - - from nemo_rl.models.policy.utils import stream_weights_via_http_impl - - # Get current GPU UUID - current_device_uuid = self.report_device_id() - def dtensor_params_generator(): """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.""" - state_dict_items = sorted( - self.model.state_dict().items(), key=lambda x: x[0] - ) - for name, tensor in state_dict_items: + for name, tensor in self.model.state_dict().items(): if isinstance(tensor, DTensor): # Convert DTensor to full tensor for streaming full_tensor = tensor.full_tensor() @@ -945,13 +2015,13 @@ def dtensor_params_generator(): # Convert to target dtype yield name, tensor.to(self.dtype, non_blocking=True).contiguous() - # Use the HTTP implementation - stream_weights_via_http_impl( + # Use the shared implementation + stream_weights_via_ipc_zmq_impl( params_generator=dtensor_params_generator(), - sglang_url_to_gpu_uuids=sglang_url_to_gpu_uuids, + buffer_size_bytes=buffer_size_bytes, + zmq_socket=self.zmq_socket, rank=self.rank, worker_name=str(self), - current_device_uuid=current_device_uuid, ) @torch.no_grad() @@ -972,11 +2042,17 @@ def broadcast_weights_for_collective( ) self.model = self.move_to_cuda(self.model) + def _dtensor_post_iter_func(tensor, dtype): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(dtype, non_blocking=True) + return tensor + # param_iterator will return (name, tensor), we only need tensor - dtensor_post_iter_func = lambda x: x[1] + dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) packed_broadcast_producer( - iterator=dtensor_params_generator(self.model, self.dtype), + iterator=iter(self.model.state_dict().items()), group=self.model_update_group, src=0, 
post_iter_func=dtensor_post_iter_func, @@ -1102,7 +2178,7 @@ def save_checkpoint( optimizer=self.optimizer, optimizer_path=optimizer_path, scheduler=self.scheduler, - tokenizer=self.tokenizer if tokenizer_path else None, + tokenizer=self.tokenizer if tokenizer_path is None else None, tokenizer_path=tokenizer_path, checkpointing_cfg=checkpointing_cfg, lora_enabled=self.lora_enabled, @@ -1141,16 +2217,10 @@ def _init_checkpoint_manager( self.checkpoint_manager = AutomodelCheckpointManager( dp_mesh=self.dp_mesh, tp_mesh=self.tp_mesh, + model_state_dict_keys=getattr(self, "model_state_dict_keys", None), moe_mesh=self.moe_mesh, ) self.checkpoint_manager.init_checkpointer( config_updates=config_updates, checkpoint_root=checkpoint_root, ) - - -@ray.remote( - runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") -) # pragma: no cover -class DTensorPolicyWorkerV2(DTensorPolicyWorkerV2Impl): - pass diff --git a/nemo_rl/models/policy/workers/dtensor_sharath.py b/nemo_rl/models/policy/workers/dtensor_sharath.py new file mode 100644 index 0000000000..70025cdea2 --- /dev/null +++ b/nemo_rl/models/policy/workers/dtensor_sharath.py @@ -0,0 +1,2762 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import contextlib +import gc +import itertools +import os +import warnings +from collections import defaultdict +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any, Generator, Iterable, Optional, Set, Union, cast + +import ray +import torch +import zmq +from accelerate import init_empty_weights +from torch import nn +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, +) +from torch.distributed.fsdp import ( + FSDPModule, +) +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from torch.distributed.tensor.experimental import context_parallel +from torch.distributed.tensor.experimental._attention import ( + set_rotate_method, +) +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, +) +from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + +from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + allgather_cp_sharded_tensor, + distributed_vocab_topk, + get_logprobs_from_vocab_parallel_logits, + _compute_distributed_log_softmax, +) +from nemo_rl.models.dtensor.parallelize import ( + _parallelize_model, + clip_grad_by_total_norm_, + get_grad_norm, + to_local_if_dtensor, +) +from nemo_rl.models.huggingface.common import ( + get_flash_attention_kwargs, + pack_sequences, +) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ( + LogprobOutputSpec, + ReferenceLogprobOutputSpec, + ScoreOutputSpec, +) +from nemo_rl.models.policy.utils import ( + configure_dynamo_cache, + get_gpu_info, + get_runtime_env_for_policy_worker, + import_class_from_path, + resolve_model_class, + sliding_window_overwrite, + 
get_handle_from_tensor, + rebuild_cuda_tensor_from_ipc, +) +from nemo_rl.utils.native_checkpoint import ( + load_checkpoint, + save_checkpoint, +) +from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import packed_broadcast_producer + + +@contextmanager +def unshard_fsdp2_model(model: nn.Module) -> Generator[None, None, None]: + """Explicitly unshard and then reshard the FSDP2 modules. Useful for logprob inference.""" + try: + for module in model.modules(): + if isinstance(module, FSDPModule): + module.unshard() + yield + finally: + for module in model.modules(): + if isinstance(module, FSDPModule): + module.reshard() + +@torch.no_grad() +def get_cpu_state_dict( + state_generator: Iterable[tuple[str, Union[torch.Tensor, DTensor]]], + pin_memory: bool = False, +) -> dict[str, torch.Tensor]: + """Copy the state dict generator to CPU memory. + + Args: + state_generator (Iterable[tuple[str, Union[torch.Tensor, DTensor]]]): + An iterable that yields (key, tensor) pairs from a model state. + pin_memory (bool, optional): + Whether to allocate the CPU tensors in pinned memory for faster GPU transfer. + Defaults to False. + + Returns: + dict[str, torch.Tensor]: A dictionary mapping parameter names to CPU tensors. + """ + new_state_dict = {} + for k, v in state_generator: + val = to_local_if_dtensor(v) + + if len(val.shape) == 0: + new_state_dict[k] = val.cpu() + else: + cpu_tensor = torch.empty( + *val.shape, device="cpu", pin_memory=pin_memory, dtype=val.dtype + ) + cpu_tensor.copy_(val, non_blocking=True) + new_state_dict[k] = cpu_tensor + + torch.cuda.synchronize() + return new_state_dict + + +@ray.remote( + runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker") +) # pragma: no cover +class DTensorPolicyWorker: + def __repr__(self) -> str: + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. 
+ """ + if torch.distributed.is_initialized(): + return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]" + else: + return f"{self.__class__.__qualname__}" + + def __init__( + self, + config: PolicyConfig, + tokenizer: AutoTokenizer, + processor: Optional[AutoProcessor] = None, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, + init_optimizer: bool = True, + init_reference_model: bool = True, + **kwargs: Any, + ): + """Initialize the DTensorPolicyWorker.""" + self.tokenizer = tokenizer + self.processor = processor + self.is_vlm = processor is not None + + print(f"Initializing DTensorPolicyWorker with is_vlm={self.is_vlm}") + + self.is_generation_colocated = None + if "generation" in config and config["generation"] is not None: + self.is_generation_colocated = config["generation"]["colocated"]["enabled"] + + # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator. + # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details. + if not self.is_generation_colocated: + os.environ["NCCL_CUMEM_ENABLE"] = "1" + + # Disable dynamo autotune_local_cache to avoid crash when there's already a cache + # with different order of node_bundles + configure_dynamo_cache() + + self.cfg = config + # torch distributed init. 
Envars for rank, world_size, and master_addr and master_port are set from the ray remote call + torch.distributed.init_process_group(backend="nccl") + self.rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + model_name = self.cfg["model_name"] + + self.teacher_logits = None + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] + self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] + self.max_grad_norm = self.cfg["max_grad_norm"] + + if self.cfg["precision"] == "float32": + self.dtype = torch.float32 + elif self.cfg["precision"] == "bfloat16": + self.dtype = torch.bfloat16 + elif self.cfg["precision"] == "float16": + self.dtype = torch.float16 + else: + raise ValueError(f"Unknown precision: {self.cfg['precision']}") + + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] + if self.enable_seq_packing: + assert not self.is_vlm, ( + "Sequence packing is not supported for VLM models. Please set policy.sequence_packing.enabled = False to train VLM models." + ) + print( + f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}" + ) + print(f"[Rank {self.rank}] Using FlashAttention2 for sequence packing") + + hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} + + model_config = AutoConfig.from_pretrained( + model_name, + # Always load the model in float32 to keep master weights in float32. + # Keeping the master weights in lower precision has shown to cause issues with convergence. 
+ torch_dtype=torch.float32, + trust_remote_code=True, + **sliding_window_overwrite( + model_name + ), # due to https://github.com/huggingface/transformers/issues/38002 + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, + **hf_config_overrides, + ) + + # diffusion model + self._is_mdlm = self.cfg.get("is_mdlm", False) + self._is_dqwn = self.cfg.get("is_dqwn", False) + + # reward model + self._is_reward_model = ( + "reward_model_cfg" in self.cfg and self.cfg["reward_model_cfg"]["enabled"] + ) + if self._is_reward_model: + # Ensure sequence packing is disabled. + if self.enable_seq_packing: + raise NotImplementedError( + "Sequence packing is not supported for reward models" + ) + # Load model as a Reward Model. + rm_type = self.cfg["reward_model_cfg"]["reward_model_type"] + if rm_type == "bradley_terry": + model_class = AutoModelForSequenceClassification + if model_config.num_labels != 1: + # For Bradley-Terry reward models, the linear head has a single output. + # In the transformers library, the default setting for model_config.num_labels is 2 + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/configuration_utils.py#L259). + # Since num_labels is used as the out_features for the linear head + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/modeling_llama.py#L738) + # if num_labels is not 1, we set it to 1. This change may trigger a warning that some weights are not initialized + # from the model checkpoint and are instead initialized using model_config.initializer_range + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/configuration_llama.py#L62). + print( + "model_config.num_labels is not 1. Setting it to 1 since this value is used as the out_features " + "for the linear head of Bradley-Terry reward models." 
+                    )
+                    model_config.num_labels = 1
+            else:
+                raise ValueError(f"Unknown reward model type: {rm_type}")
+        else:
+            # Check if the model requires AutoModel instead of AutoModelForCausalLM
+            if model_name in ["nvidia/Nemotron-Diffusion-Research-4B-v0", "nvidia/Nemotron-Diffusion-Research-8B-v0"]:
+                print(f"[Rank {self.rank}] Model {model_name} is not a causal LM, using AutoModel instead.")
+                from transformers import AutoModel
+                model_class = AutoModel
+                if "mdlm" in self.cfg and self.cfg["mdlm"].get("use_block_diff", False):
+                    model_config.dlm_paradigm = "block_diff"
+                if "mdlm" in self.cfg and self.cfg["mdlm"].get("block_size") is not None:
+                    model_config.block_size = self.cfg["mdlm"]["block_size"]
+            elif model_name in ["nvidia/Nemotron-Diffusion-Exp-Ministral-8B", "nvidia/Nemotron-Diffusion-Exp-Ministral-3B"]:
+                print(f"[Rank {self.rank}] Model {model_name} is not a causal LM, using AutoModel instead.")
+                from transformers import AutoModel
+                model_class = AutoModel
+                if "mdlm" in self.cfg and self.cfg["mdlm"].get("use_block_diff", False):
+                    model_config.dlm_paradigm = "sbd_block_diff"
+                if "mdlm" in self.cfg and self.cfg["mdlm"].get("block_size") is not None:
+                    model_config.block_size = self.cfg["mdlm"]["block_size"]
+            else:
+                # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc.
+ model_class = resolve_model_class(model_config.model_type) + + full_state_dict = None + embed_weight_for_local = None + if self.rank == 0: + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + model = model_class.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + trust_remote_code=True, + config=model_config, + ) + full_state_dict = model.state_dict() + + # Extract embedding weight for local copy (before FSDP wrapping) + for key in full_state_dict: + # added 'backbone.embeddings.weight' for nano teacher model + if "embed_tokens.weight" in key or "embeddings.weight" in key: + embed_weight_for_local = full_state_dict[key].clone() + print(f"[Rank {self.rank}] Found embedding weight: {key}, shape={embed_weight_for_local.shape}") + break + + del model + + print(f"[Rank {self.rank}] Initializing empty model for FSDP...") + # All ranks initialize model on meta device, so FSDP can shard it. + # The actual weights will be broadcast from rank 0. + with init_empty_weights(): + self.model = model_class.from_config( + model_config, + trust_remote_code=True, + ) + + if self.model.config.pad_token_id is None: + self.model.config.pad_token_id = tokenizer.pad_token_id + + tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] + cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + if cp_size > 1 and self.enable_seq_packing: + raise ValueError( + "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." 
+            )
+        dp_size = world_size // tp_size // cp_size
+        sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"]
+        assert world_size == dp_size * tp_size * cp_size, (
+            f"World size ({world_size}) must equal dp_size ({dp_size}) * tp_size ({tp_size}) * cp_size ({cp_size}) to use DTensor"
+        )
+
+        if sequence_parallel_enabled and tp_size == 1:
+            print(
+                "[WARNING]: sequence_parallel=True, but tp_size=1 which has no effect. Enable tp_size > 1 to use sequence parallelism."
+            )
+        elif sequence_parallel_enabled and tp_size > 1:
+            raise RuntimeError(
+                "Sequence parallel + tp_size > 1 is currently broken in torch==2.8.0. See https://github.com/NVIDIA-NeMo/Automodel/issues/652 for more details."
+            )
+
+        if cp_size > 1:
+            assert not isinstance(self.model, Gemma3ForCausalLM), (
+                "Context parallel is not supported for Gemma3ForCausalLM. Torch context parallel has many limitations. "
+                "Please refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
+            )
+
+            assert not sequence_parallel_enabled, (
+                "It's a known issue that context parallel can't be used together with sequence parallel in DTensor worker. "
+                "Please either set cp_size = 1 or disable sequence parallel. "
+                "See https://github.com/NVIDIA-NeMo/RL/issues/659 for more details."
+            )
+
+            assert not self.is_vlm, (
+                "Context parallel is not yet supported for VLM models. Please set cp_size = 1 to train VLM models."
+            )
+
+        # torch==2.8 uses LOCAL_RANK to set the device here (https://github.com/pytorch/pytorch/blob/ba56102387ef21a3b04b357e5b183d48f0afefc7/torch/distributed/device_mesh.py#L500),
+        # but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0.
+ # TODO: consider changing the default LOCAL_RANK set in worker_groups.py + prev_local_rank = os.environ["LOCAL_RANK"] + os.environ["LOCAL_RANK"] = "0" + + device_mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", (dp_size, cp_size, tp_size), mesh_dim_names=("dp", "cp", "tp") + ) + os.environ["LOCAL_RANK"] = prev_local_rank + + self.dp_cp_mesh = device_mesh[("dp", "cp")]._flatten(mesh_dim_name="dp_cp") + + self.dp_mesh, self.tp_mesh, self.cp_mesh = ( + device_mesh["dp"], + device_mesh["tp"], + device_mesh["cp"], + ) + self.dp_size = dp_size + self.tp_size = tp_size + self.cp_size = cp_size + self.device_mesh = device_mesh + + # ------------------------------------------------ + # 3) Move to GPU + Composable FSDP + # (Initialize device mesh, shard submodules, then shard entire model) + # ------------------------------------------------ + self.model = _parallelize_model( + self.model, + self.dp_cp_mesh, + self.tp_mesh, + param_dtype=self.dtype, + sequence_parallel=sequence_parallel_enabled, + cpu_offload=self.cpu_offload, + activation_checkpointing=self.cfg["dtensor_cfg"][ + "activation_checkpointing" + ], + custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], + ) + + print(f"[Rank {self.rank}] Loading state dict from rank 0...") + # This will broadcast the state dict from rank 0 to all other ranks + # and load it into the FSDP model. 
+ set_model_state_dict( + self.model, + model_state_dict=full_state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + ), + ) + + # Handle tied word embeddings after loading the state dict + # We need to actually tie the parameters at the model level + is_tied_lm_head = hasattr(self.model, "lm_head") and getattr( + getattr(self.model, "config", {}), "tie_word_embeddings", False + ) + if is_tied_lm_head: + embed_tokens_weight = None + for name, param in self.model.named_parameters(): + if "embed_tokens" in name and name.endswith(".weight"): + embed_tokens_weight = param + break + + if embed_tokens_weight is not None: + self.model.lm_head.weight = embed_tokens_weight + + # Manually broadcast buffers + for _, buf in self.model.named_buffers(): + torch.distributed.broadcast(to_local_if_dtensor(buf), src=0) + + # Create local embedding layer (for latent thinking, avoids DTensor issues) + # First broadcast the embedding weight shape, then the weight itself + if self.rank == 0: + embed_shape = torch.tensor(list(embed_weight_for_local.shape), dtype=torch.long, device="cuda") + else: + embed_shape = torch.zeros(2, dtype=torch.long, device="cuda") + torch.distributed.broadcast(embed_shape, src=0) + vocab_size, hidden_dim = int(embed_shape[0].item()), int(embed_shape[1].item()) + + if self.rank != 0: + embed_weight_for_local = torch.zeros(vocab_size, hidden_dim, dtype=self.dtype, device="cuda") + else: + embed_weight_for_local = embed_weight_for_local.to(dtype=self.dtype, device="cuda") + torch.distributed.broadcast(embed_weight_for_local, src=0) + + self.embed_layer_local = nn.Embedding(vocab_size, hidden_dim) + self.embed_layer_local.weight.data.copy_(embed_weight_for_local.cpu()) + self.embed_layer_local.weight.requires_grad = False + print(f"[Rank {self.rank}] Created local embedding: vocab_size={vocab_size}, hidden_dim={hidden_dim}") + del embed_weight_for_local + + if self.cpu_offload: + self.model = self.move_to_device(self.model, 
"cpu") + + if init_reference_model: + self.reference_model_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) + + if init_optimizer: + optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), **self.cfg["optimizer"]["kwargs"] + ) + else: + self.optimizer = None + + if "scheduler" in self.cfg and self.optimizer is not None: + if isinstance(self.cfg["scheduler"], dict): + scheduler_cls = import_class_from_path( + cast(str, self.cfg["scheduler"]["name"]) + ) + self.scheduler = scheduler_cls( + self.optimizer, **self.cfg["scheduler"]["kwargs"] + ) + else: + schedulers = [] + for scheduler_cfg in self.cfg["scheduler"]: + if "name" in scheduler_cfg: + schedulers.append( + import_class_from_path(scheduler_cfg["name"])( + self.optimizer, **scheduler_cfg["kwargs"] + ) + ) + else: + assert "milestones" in scheduler_cfg, ( + "unknown scheduler config: ", + scheduler_cfg, + ) + milestones: list[int] = scheduler_cfg["milestones"] + + self.scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, schedulers, milestones + ) + + elif self.optimizer is not None: + ## default to a passthrough LR schedule + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: 1 + ) + + # restore + if weights_path: + self.load_checkpoint(weights_path, optimizer_path) + else: + print( + "No weights path provided. Starting from scratch (default policy init)" + ) + + # Refer to nemo impl. Below is original comment. + # based on https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/utils.py#L113 + @staticmethod + def create_context_parallel_ctx( + cp_mesh: torch.distributed.device_mesh.DeviceMesh, + cp_buffers: list[torch.Tensor], + cp_seq_dims: list[int], + cp_no_restore_buffers: Set[torch.Tensor], + cp_rotate_method: Optional[str] = None, + ): + """Create a context parallel context. 
+
+        Args:
+            cp_mesh (DeviceMesh): The device mesh for context parallel.
+            cp_buffers (list[torch.Tensor]): The buffers for context parallel.
+            cp_seq_dims (list[int]): The sequence dimensions for context parallel.
+            cp_no_restore_buffers (Set[torch.Tensor]): The no restore buffers for context parallel.
+            cp_rotate_method (Optional[str]): The rotation method for context parallel, such as "allgather" or "alltoall".
+        """
+        if cp_rotate_method is not None:
+            set_rotate_method(cp_rotate_method)
+
+        return context_parallel(
+            cp_mesh,
+            buffers=cp_buffers,
+            buffer_seq_dims=cp_seq_dims,
+            no_restore_buffers=cp_no_restore_buffers,
+        )
+
+    # Refer to nemo impl. Below is original comment.
+    # based on https://github.com/pytorch/torchtitan/blob/cddd7dc809f36fe0ed51cdaaea0671c084d75442/torchtitan/distributed/utils.py#L178
+
+    def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
+        if "generation" in self.cfg and self.cfg["generation"] is not None:
+            logits.div_(self.cfg["generation"]["temperature"])
+        return logits
+
+    @staticmethod
+    @contextlib.contextmanager
+    def train_context(cp_context: Optional[Generator[None, None, None]] = None):
+        with contextlib.ExitStack() as stack:
+            if cp_context is not None:
+                from torch.nn.attention import SDPBackend, sdpa_kernel
+                # TODO (xilunwu): support cuDNN backend
+
+                stack.enter_context(
+                    sdpa_kernel(
+                        [
+                            SDPBackend.FLASH_ATTENTION,
+                            SDPBackend.EFFICIENT_ATTENTION,
+                        ]
+                    )
+                )
+
+                stack.enter_context(cp_context)
+
+            yield
+
+    def init_collective(
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
+    ) -> None:
+        """Initialize the collective communication."""
+        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+        from vllm.distributed.utils import StatelessProcessGroup
+
+        pg = StatelessProcessGroup.create(
+            host=ip, port=port, rank=self.rank, world_size=world_size
+        )
+        device = torch.cuda.current_device()
+        self.model_update_group = PyNcclCommunicator(pg, device=device)
+
+    def is_alive(self) -> bool:
+        return True
+
+    def reset_peak_memory_stats(self) -> None:
+        torch.cuda.reset_peak_memory_stats()
+
+    def get_gpu_info(self) -> dict[str, Any]:
+        """Return information about the GPU being used by this worker."""
+        return get_gpu_info(self.model)
+
+    @wrap_with_nvtx_name("dtensor_policy_worker/train_latent_thinking")
+    def train_latent_thinking(
+        self,
+        data: BatchedDataDict[Any],
+        loss_fn: LossFunction,
+        num_iterations: int,
+        detach_across_iterations: bool = True,
+        eval_mode: bool = False,
+        gbs: Optional[int] = None,
+        mbs: Optional[int] = None,
+        soft_mixing: bool = False,
+        soft_mixing_top_k: int = 5,
+        mask_token_id: Optional[int] = None,
+    ) -> dict[str, Any]:
+        """Train with latent thinking for diffusion LM.
+
+        This method implements iterative training where:
+        - Thinking tokens use hidden features from previous iteration
+        - Answer tokens are always masked
+        - Loss is computed at each iteration only on answer tokens
+
+        Args:
+            data: BatchedDataDict with:
+                - input_ids: Token IDs with thinking replaced by mask tokens
+                - token_ids: Original target token IDs
+                - thinking_mask: Boolean mask for thinking positions
+                - answer_mask: Boolean mask for answer positions
+                - token_mask: Mask for loss computation (only answer tokens)
+                - sample_mask: Sample validity mask
+            loss_fn: Loss function to use
+            num_iterations: Number of iterations (k)
+            detach_across_iterations: If True, detach gradients between iterations
+            gbs: Global batch size
+            mbs: Micro batch size
+            soft_mixing: If True, build thinking_hidden_features as weighted average of
+                top-k token embeddings mixed with mask token embedding. The mask token
+                weight is (1 - top1_confidence) and top-k tokens share the remaining
+                top1_confidence proportionally.
+ soft_mixing_top_k: Number of top tokens to use for soft mixing (default: 5) + mask_token_id: Token ID for mask token (required if soft_mixing=True) + + Returns: + Training metrics dictionary + """ + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + sequence_dim = 1 + + if eval_mode: + self.model.eval() + else: + self.model.train() + data.to("cuda") + + losses = [] + all_mb_metrics = [] + + for gb_idx in range(num_global_batches): + global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + + assert "sample_mask" in global_batch, "sample_mask must be present in the data!" + local_valid_seqs = torch.sum(global_batch["sample_mask"]) + + local_valid_toks = torch.sum( + global_batch["token_mask"] * global_batch["sample_mask"].unsqueeze(-1) + ) + + to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda() + torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group()) + global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1] + + self.optimizer.zero_grad() + mb_losses = [] + batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + + # Setup microbatch iterator based on sequence packing + if self.enable_seq_packing: + mb_iterator = batch.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. 
+ # We add dummy batches to the end of the iterator to make the batch counts equal. + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = batch.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + mb_iterator = itertools.chain(mb_iterator, dummy_iterator) + else: + mb_iterator = batch.make_microbatch_iterator(mbs) + iterator_len = batch.size // mbs + + for mb_idx, mb in enumerate(mb_iterator): + # Get the required tensors + # input_ids has thinking AND answer tokens set to mask tokens + input_ids = mb.get("input_ids").cuda() + # target_ids has answer tokens as original (for loss computation) + target_ids = mb.get("target_ids").cuda() if "target_ids" in mb else input_ids + thinking_mask = mb.get("thinking_mask").cuda() # Boolean mask for thinking positions + token_mask = mb.get("token_mask").cuda() # Loss mask (answer tokens only) + + batch_size, seq_len = input_ids.shape + + # Handle sequence packing + if self.enable_seq_packing: + input_lengths = mb["input_lengths"] + packed_sequence_size = [len(input_lengths)] + + # Pack input_ids + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=packed_sequence_size, + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"]["train_mb_tokens"], + ) + + # Pack target_ids + target_ids, _, _ = pack_sequences( + input_ids=target_ids, + input_lengths=input_lengths, + packed_sequence_size=packed_sequence_size, + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"]["train_mb_tokens"], + ) + + # Pack thinking_mask (convert to int for packing, then back to bool) + thinking_mask_int = thinking_mask.int() + thinking_mask_int, _, _ = pack_sequences( + input_ids=thinking_mask_int, + input_lengths=input_lengths, + 
packed_sequence_size=packed_sequence_size, + padding_value=0, + return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"]["train_mb_tokens"], + ) + thinking_mask = thinking_mask_int.bool() + + # Pack token_mask + token_mask, _, _ = pack_sequences( + input_ids=token_mask, + input_lengths=input_lengths, + packed_sequence_size=packed_sequence_size, + padding_value=0, + return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"]["train_mb_tokens"], + ) + + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs(input_lengths=input_lengths) + else: + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + position_ids = torch.arange(seq_len, device=input_ids.device).repeat(batch_size, 1) + flash_attn_kwargs = {} + + # Initialize hidden features for thinking positions (None for first iteration) + thinking_hidden_features = None + + # Track metrics across iterations + iteration_metrics = [] + total_loss_value = 0.0 + + for iter_idx in range(num_iterations): + with torch.autocast(device_type="cuda", dtype=self.dtype): + # Build input embeddings + # input_ids already has mask tokens at thinking AND answer positions + # Use self.embed_layer_local - a non-parallelized copy of embedding layer + # This avoids DTensor mixing issues with FSDP2 + assert self.embed_layer_local is not None, "Local embedding layer not initialized" + self.embed_layer_local = self.embed_layer_local.to(input_ids.device) + base_embeddings = self.embed_layer_local(input_ids) + + if iter_idx > 0 and thinking_hidden_features is not None: + # For subsequent iterations, use hidden features at thinking positions + # thinking_hidden_features shape: [batch, seq_len, hidden_dim] + # thinking_mask shape: [batch, seq_len] + + # Replace thinking positions with hidden features from previous iteration + inputs_embeds = base_embeddings.clone() + thinking_mask_expanded = 
thinking_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = torch.where( + thinking_mask_expanded, + thinking_hidden_features, + inputs_embeds + ) + else: + inputs_embeds = base_embeddings + + # Forward pass with inputs_embeds instead of input_ids + model_args = dict( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + output_hidden_states=True, # Need hidden states for next iteration + flash_attn_kwargs=flash_attn_kwargs, + ) + + # Handle MDLM/DiffQwen specific args + if self._is_dqwn: + del model_args["use_cache"] + del model_args["attention_mask"] + del model_args["position_ids"] + del model_args["flash_attn_kwargs"] + model_args["masked_indices"] = token_mask + model_args["labels"] = target_ids + model_args["skip_loss"] = True + elif self._is_mdlm: + del model_args["position_ids"] + del model_args["flash_attn_kwargs"] + + outputs = self.model(**model_args) + + # Get logits first (needed for both loss and soft_mixing) + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + + # Get hidden states for next iteration + if soft_mixing: + # Build thinking_hidden_features as weighted average of top-k token embeddings + # mixed with mask token embedding + assert mask_token_id is not None, "mask_token_id must be provided when soft_mixing=True" + + # Use self.embed_layer_local for embedding lookup (no DTensor issues) + self.embed_layer_local = self.embed_layer_local.to(logits.device) + embed_weight_for_mixing = self.embed_layer_local.weight.data + + # Compute probabilities from logits + probs = torch.softmax(logits, dim=-1) # [batch, seq_len, vocab_size] + + # Get top-k tokens and their probabilities + top_k_probs, top_k_indices = torch.topk(probs, k=soft_mixing_top_k, dim=-1) # [batch, seq_len, k] + + # Get top-1 confidence (probability of the most likely token) + top1_confidence = top_k_probs[..., 0:1] # 
[batch, seq_len, 1] + + # Normalize top-k weights to sum to top1_confidence + # weight for each top-k token = (prob_i / sum(top_k_probs)) * top1_confidence + top_k_sum = top_k_probs.sum(dim=-1, keepdim=True) # [batch, seq_len, 1] + top_k_weights = (top_k_probs / (top_k_sum + 1e-8)) * top1_confidence # [batch, seq_len, k] + + # Get mask token embedding + mask_embedding = embed_weight_for_mixing[mask_token_id] # [hidden_dim] + + # Weight for mask embedding is (1 - top1_confidence) + mask_weight = 1.0 - top1_confidence # [batch, seq_len, 1] + + # Get embeddings for top-k tokens + # top_k_indices: [batch, seq_len, k] + batch_size_mix, seq_len_mix, k = top_k_indices.shape + hidden_dim = embed_weight_for_mixing.shape[-1] + + # Flatten indices for indexing + flat_top_k_indices = top_k_indices.view(-1) # [batch * seq_len * k] + top_k_embeddings_flat = embed_weight_for_mixing[flat_top_k_indices] # [batch * seq_len * k, hidden_dim] + top_k_embeddings = top_k_embeddings_flat.view(batch_size_mix, seq_len_mix, k, hidden_dim) # [batch, seq_len, k, hidden_dim] + + # Compute weighted sum of top-k embeddings + # top_k_weights: [batch, seq_len, k] -> [batch, seq_len, k, 1] + top_k_weights_expanded = top_k_weights.unsqueeze(-1) # [batch, seq_len, k, 1] + weighted_top_k = (top_k_embeddings * top_k_weights_expanded).sum(dim=2) # [batch, seq_len, hidden_dim] + + # Add weighted mask embedding + # mask_embedding: [hidden_dim] -> broadcast to [batch, seq_len, hidden_dim] + # mask_weight: [batch, seq_len, 1] + weighted_mask = mask_weight * mask_embedding # [batch, seq_len, hidden_dim] + + # Final thinking_hidden_features is the sum + thinking_hidden_features = weighted_top_k + weighted_mask # [batch, seq_len, hidden_dim] + else: + # Original approach: use hidden states from forward pass + if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None: + # Use the last hidden state + thinking_hidden_features = outputs.hidden_states + else: + raise ValueError("Could not find 
hidden states in outputs") + + del outputs + + # Detach hidden features if needed (for memory efficiency) + if detach_across_iterations and thinking_hidden_features is not None: + thinking_hidden_features = thinking_hidden_features.detach() + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + # Compute loss only on answer tokens + # Create a data dict for the loss function + loss_data = BatchedDataDict({ + "input_ids": target_ids, # Target tokens + "token_mask": token_mask, # Only answer tokens + "sample_mask": mb["sample_mask"].cuda(), + "p_mask": mb["p_mask"].cuda(), + "seq_index": mb["seq_index"].cuda() if "seq_index" in mb else None, + }) + + # Use sequence packing loss wrapper if enabled + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + loss, loss_metrics = loss_fn_( + logits, + loss_data, + global_valid_seqs, + global_valid_toks, + ) + del logits + + iteration_metrics.append(loss_metrics) + + # Scale loss by dp_size for gradient averaging + scaled_loss = loss * self.dp_size * self.cp_size + total_loss_value += loss.item() + + if not eval_mode: + if detach_across_iterations: + # Backward after each iteration to save memory + # Gradients accumulate in model parameters + scaled_loss.backward() + else: + # Accumulate loss for single backward at the end + if iter_idx == 0: + accumulated_loss = scaled_loss + else: + accumulated_loss = accumulated_loss + scaled_loss + + # If not detaching, do single backward with accumulated loss + if not eval_mode and not detach_across_iterations: + accumulated_loss.backward() + + # Aggregate metrics across iterations + aggregated_metrics = {} + for m in iteration_metrics: + for k, v in m.items(): + if k not in aggregated_metrics: + aggregated_metrics[k] = 0 + aggregated_metrics[k] += v / num_iterations + + 
aggregated_metrics["num_iterations"] = num_iterations + aggregated_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + aggregated_metrics["global_valid_seqs"] = global_valid_seqs.item() + aggregated_metrics["global_valid_toks"] = global_valid_toks.item() + + mb_losses.append(total_loss_value) + all_mb_metrics.append(aggregated_metrics) + + # Gradient clipping and optimizer step + grad_norm = None + if not eval_mode: + with torch.no_grad(): + grad_norm = get_grad_norm( + self.model.parameters(), + dp_cp_group=self.dp_cp_mesh.get_group(), + tp_group=self.tp_mesh.get_group(), + dtype=torch.float32, + ) + if self.max_grad_norm is not None: + clip_grad_by_total_norm_( + self.model.parameters(), + max_grad_norm=self.max_grad_norm, + total_norm=grad_norm, + ) + grad_norm = torch.tensor([grad_norm]) + + self.optimizer.step() + losses.append(torch.tensor(mb_losses).sum().item()) + + # Clean up + self.optimizer.zero_grad() + if not eval_mode: + self.scheduler.step() + torch.cuda.empty_cache() + + # Compute global loss + with torch.no_grad(): + global_loss = torch.tensor(losses, device="cuda") + torch.distributed.all_reduce( + global_loss, group=self.dp_mesh.get_group() + ) + + # Aggregate metrics + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "grad_norm": grad_norm, + "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, + "all_mb_metrics": dict(mb_metrics), + } + + return metrics + + @wrap_with_nvtx_name("dtensor_policy_worker/train") + def train( + self, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + is_teacher: Optional[bool] = False, + teacher_logits: Optional = None, + topk_logits: Optional[int] = None, + ) -> dict[str, Any]: + """Train the policy on a batch of data with a given loss function.""" + if 
gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + + if is_teacher: + if self.teacher_logits is not None: + # With the reusable buffer approach, we no longer need to free old IPC handles. + # The same buffer and IPC handle are reused across iterations. + print(f" rank {torch.distributed.get_rank()} Teacher GPU memory at start: {torch.cuda.memory_allocated() / 1e9} {torch.cuda.memory_reserved() / 1e9}") + + if eval_mode: + ctx: AbstractContextManager[Any] = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + # Ensure model is in training mode + self.model.train() + + with ctx: + # Get data from batch and move to device + data.to("cuda") + + losses = [] + all_mb_metrics = [] + for gb_idx in range(num_global_batches): + global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + + assert "sample_mask" in global_batch, ( + "sample_mask must be present in the data!" 
+                )
+                ## get the normalization factor for the loss
+                local_valid_seqs = torch.sum(global_batch["sample_mask"])
+
+                if "token_mask" not in global_batch:
+                    local_valid_toks = (
+                        local_valid_seqs * global_batch["input_ids"].shape[1]
+                    )
+                else:
+                    if self._is_dqwn:
+                        local_valid_toks = torch.sum(
+                            global_batch["token_mask"]
+                            * global_batch["sample_mask"].unsqueeze(-1)
+                        )
+                    elif self._is_mdlm and "noise_mask" in global_batch:
+                        local_valid_toks = torch.sum(
+                            global_batch["noise_mask"]
+                            * global_batch["sample_mask"].unsqueeze(-1)
+                        )
+                    else:
+                        local_valid_toks = torch.sum(
+                            global_batch["token_mask"][:, 1:]
+                            * global_batch["sample_mask"].unsqueeze(-1)
+                        )
+
+                to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda()
+                torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group())
+                global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1]
+
+                global_valid_toks_ar = None
+                if self._is_dqwn and "token_mask_ar" in global_batch:
+                    local_valid_toks_ar = torch.sum(
+                        global_batch["token_mask_ar"]
+                        * global_batch["sample_mask"].unsqueeze(-1)
+                    )
+                    to_reduce = torch.tensor([local_valid_toks_ar]).cuda()
+                    torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group())
+                    global_valid_toks_ar = to_reduce[0]
+
+                if (
+                    hasattr(loss_fn, "loss_type")
+                    and loss_fn.loss_type == LossType.TOKEN_LEVEL
+                ):
+                    assert "token_mask" in global_batch, (
+                        "token_mask must be present in the data when using token-level loss"
+                    )
+
+                if not is_teacher:
+                    self.optimizer.zero_grad()
+                mb_losses = []
+                batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs)
+                # Calculate number of microbatches to process.
+                # make_microbatch_iterator assumes that the batch size is a multiple of
+                # the microbatch size, so it's safe not to check for the case where the
+                # last data slice is smaller than mbs.
+                dummy_iterator = iter([])
+                if self.cfg["dynamic_batching"]["enabled"]:
+                    mb_iterator = batch.make_microbatch_iterator_with_dynamic_shapes()
+                    iterator_len = batch.get_microbatch_iterator_dynamic_shapes_len()
+                elif self.enable_seq_packing:
+                    mb_iterator = (
+                        batch.make_microbatch_iterator_for_packable_sequences()
+                    )
+                    iterator_len, max_seqlen = (
+                        batch.get_microbatch_iterator_for_packable_sequences_len()
+                    )
+                    max_batch_ct = torch.tensor([iterator_len], device="cuda")
+                    torch.distributed.all_reduce(
+                        max_batch_ct, op=torch.distributed.ReduceOp.MAX
+                    )
+
+                    # Sequence packing can end up with unevenly distributed batch counts
+                    # across DP ranks. We add dummy batches to the end of the iterator
+                    # to make the batch counts equal.
+                    dummy_batch_ct = int(max_batch_ct.item() - iterator_len)
+                    dummy_iterator = (
+                        batch.make_microbatch_iterator_for_packable_sequences()
+                    )
+                    dummy_iterator = itertools.islice(
+                        itertools.cycle(dummy_iterator), dummy_batch_ct
+                    )
+                else:
+                    mb_iterator = batch.make_microbatch_iterator(mbs)
+                    iterator_len = batch.size // mbs
+
+                empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get(
+                    "clear_cache_every_n_steps"
+                )
+                if empty_cache_steps:
+                    warnings.warn(
+                        f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead."
+                    )
+
+                all_teacher_logits = None
+
+                # Reconstruct teacher tensor from IPC ONCE before the microbatch loop.
+                # Previously, rebuild_cuda_tensor_from_ipc was called inside the loss function
+                # on every microbatch, opening a new cudaIpcOpenMemHandle mapping each time.
+                # This caused unnecessary memory pressure and potential IPC mapping leaks.
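The rebuild-once optimization described in the comment above can be illustrated with a self-contained sketch (hypothetical names; a plain tuple stands in for the CUDA IPC handle, and a counter stands in for the cost of each `cudaIpcOpenMemHandle` mapping):

```python
# Standalone sketch: an expensive per-call rebuild is hoisted out of the
# microbatch loop so it runs exactly once per global batch.
rebuild_calls = 0

def rebuild_from_ipc(handle):
    """Stand-in for rebuild_cuda_tensor_from_ipc (each call maps memory)."""
    global rebuild_calls
    rebuild_calls += 1
    return list(handle)  # pretend this maps shared memory into a tensor

def run_global_batch(handle, microbatches):
    teacher = rebuild_from_ipc(handle)            # reconstruct ONCE, before the loop
    return [sum(teacher) + mb for mb in microbatches]  # reuse it per microbatch

losses = run_global_batch((1, 2, 3), [10, 20, 30])
```

The key property is that `rebuild_from_ipc` is invoked once regardless of the number of microbatches, mirroring the hoisting done in the real code.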
+                teacher_logits_tensor = None
+                teacher_topk_indices_tensor = None
+                if not is_teacher and teacher_logits is not None:
+                    list_index = torch.distributed.get_rank()
+                    dict_index = list_index
+                    current_device_id = torch.cuda.current_device()
+                    actual_shape = teacher_logits[list_index].get('actual_shape') if isinstance(teacher_logits[list_index], dict) else None
+                    is_topk = teacher_logits[list_index].get('is_topk', False) if isinstance(teacher_logits[list_index], dict) else False
+                    teacher_logits_tensor = rebuild_cuda_tensor_from_ipc(
+                        teacher_logits[list_index][dict_index], current_device_id
+                    ).detach()
+                    # Slice to actual data dimensions if buffer was larger (grow-only IPC buffer)
+                    if actual_shape is not None:
+                        aB, aS, aV = actual_shape
+                        teacher_logits_tensor = teacher_logits_tensor[:aB, :aS, :aV].clone()
+                    else:
+                        teacher_logits_tensor = teacher_logits_tensor.clone()
+
+                    # If topk mode, also reconstruct the indices tensor from its IPC buffer
+                    if is_topk and 'topk_indices_ipc' in teacher_logits[list_index]:
+                        teacher_topk_indices_tensor = rebuild_cuda_tensor_from_ipc(
+                            teacher_logits[list_index]['topk_indices_ipc'], current_device_id
+                        ).detach()
+                        if actual_shape is not None:
+                            teacher_topk_indices_tensor = teacher_topk_indices_tensor[:aB, :aS, :aV].clone()
+                        else:
+                            teacher_topk_indices_tensor = teacher_topk_indices_tensor.clone()
+
+                for mb_idx, mb in enumerate(
+                    itertools.chain(mb_iterator, dummy_iterator)
+                ):
+                    # Conditionally empty cache when sensitive to fragmentation
+                    if empty_cache_steps and mb_idx % empty_cache_steps == 0:
+                        torch.cuda.empty_cache()
+
+                    with torch.autocast(device_type="cuda", dtype=self.dtype):
+                        if self.enable_seq_packing:
+                            input_ids = mb.get("input_ids").cuda()
+                            input_ids, position_ids, _ = pack_sequences(
+                                input_ids=input_ids,
+                                input_lengths=mb["input_lengths"],
+                                packed_sequence_size=[
+                                    len(mb["input_lengths"])
+                                ],  # flash attention 2 expects flattened input
+                                padding_value=self.tokenizer.eos_token_id,
return_attention_mask=False, + min_seq_len=self.cfg["sequence_packing"][ + "train_mb_tokens" + ], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled. + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=mb["input_lengths"], + ) + + else: + input_ids = mb.get("input_ids").cuda() + batch_size, seq_len = input_ids.shape + + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} + masked_indices = None + if self._is_dqwn and self._is_mdlm: + masked_indices = mb.get("token_mask").cuda() + + # add vlm kwargs to model call + vlm_kwargs = mb.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + if len(vlm_kwargs) > 0: + position_ids = None + assert not self.cfg["dtensor_cfg"]["sequence_parallel"], ( + "Sequence parallel is not supported with multimodal since there's an issue when you do not pass position_ids. 
See https://github.com/NVIDIA-NeMo/Automodel/issues/652" + ) + + context_parallel_ctx = None + if self.cp_size > 1: + assert len(vlm_kwargs) == 0, ( + f"multimodal kwargs={vlm_kwargs} are not supported for context parallel" + ) + seq_index = torch.arange( + seq_len, device=input_ids.device + ).repeat(1, 1) + cp_buffers = ( + [input_ids, masked_indices, seq_index] if self._is_dqwn and self._is_mdlm else [input_ids, position_ids, seq_index] + if self.cp_size > 1 + else [] + ) + + # Create context parallel context + context_parallel_ctx = self.create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with DTensorPolicyWorker.train_context(context_parallel_ctx): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + if self._is_mdlm: + del model_args["position_ids"] + del model_args["flash_attn_kwargs"] + + if self._is_dqwn: + assert self._is_mdlm, "DiffQwen is only supported for MDLM" + + del model_args["use_cache"] + # internally uses block diffusion attention mask + del model_args["attention_mask"] + + model_args = { + **model_args, + #"masked_indices": mb["token_mask"], + "masked_indices": masked_indices, + #"p_mask": mb["p_mask"], + "labels": input_ids, + "skip_loss": True, + } + + if self._is_reward_model: + # `flash_attn_kwarg` is not supported for `LlamaForSequenceClassification`. + # Note that it should be empty anyway since sequence packing + # is not supported for reward models. 
+                            assert not flash_attn_kwargs
+                            del model_args["flash_attn_kwargs"]
+                        # remove flash_attn_kwargs if there are multimodal kwargs
+                        if len(vlm_kwargs) > 0:
+                            del model_args["flash_attn_kwargs"]
+
+                        outputs = self.model(**model_args)
+
+                    with torch.set_grad_enabled(not is_teacher):
+                        # Get logprobs
+                        causal_logits = None
+                        if not hasattr(outputs, "logits"):
+                            logits = self.model.lm_head(outputs.last_hidden_state)
+                        else:
+                            logits = outputs.logits
+                            if hasattr(outputs, "causal_logits") and outputs.causal_logits is not None:
+                                causal_logits = outputs.causal_logits
+                                del outputs
+                                logits = torch.cat([logits, causal_logits], dim=1)
+                            else:
+                                del outputs
+
+                        # Apply temperature scaling
+                        logits = self._apply_temperature_scaling(logits)
+
+                        if self.cp_size > 1:
+                            seq_index_dtensor = (
+                                DTensor.from_local(
+                                    seq_index,
+                                    device_mesh=self.cp_mesh,
+                                    placements=[Shard(1)],
+                                )
+                                .full_tensor()
+                                .squeeze(0)
+                            )
+
+                            mb["seq_index"] = seq_index_dtensor
+
+                            token_mask_dtensor_orig = (
+                                DTensor.from_local(
+                                    masked_indices,
+                                    device_mesh=self.cp_mesh,
+                                    placements=[Shard(sequence_dim)],
+                                )
+                                .full_tensor()
+                            )
+
+                            _, sorted_indices = torch.sort(seq_index_dtensor)
+                            token_mask_dtensor = token_mask_dtensor_orig[:, sorted_indices]
+
+                            mb["token_mask"] = token_mask_dtensor
+
+                            for tensor_name in mb:
+                                current_tensor = mb[tensor_name]
+                                for buffer in cp_buffers:
+                                    if current_tensor is buffer:
+                                        assert type(current_tensor) == torch.Tensor, (
+                                            f"tensor {tensor_name} is not a tensor"
+                                        )
+                                        mb[tensor_name] = DTensor.from_local(
+                                            current_tensor,
+                                            device_mesh=self.cp_mesh,
+                                            placements=[Shard(sequence_dim)],
+                                        )
+                                        break
+
+                        if 
isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + + if is_teacher: + with torch.no_grad(): + if all_teacher_logits is None: + all_teacher_logits = logits + else: + all_teacher_logits = torch.cat([all_teacher_logits, logits], dim=0) + else: + if self.enable_seq_packing: + loss_fn_ = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + if teacher_logits_tensor is not None: + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + torch.cat([global_valid_toks.unsqueeze(0), global_valid_toks_ar.unsqueeze(0)], dim=0) if (causal_logits is not None and global_valid_toks_ar is not None) else global_valid_toks, + teacher_logits=teacher_logits_tensor, + mb_idx=mb_idx, + mbs=mbs, + teacher_topk_indices_ipc=teacher_topk_indices_tensor, + ) + else: + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + torch.cat([global_valid_toks.unsqueeze(0), global_valid_toks_ar.unsqueeze(0)], dim=0) if (causal_logits is not None and global_valid_toks_ar is not None) else global_valid_toks, + ) + + del logits, causal_logits + + # skip the update for dummy batches + if mb_idx < iterator_len: + ## scale by the number of global batches so we get the correct + ## value when summing metrics across all microbatches + for k in loss_metrics.keys(): + loss_metrics[k] /= num_global_batches + num_valid_samples = 
loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + else: + loss *= 0 + + # Backward pass + if not eval_mode: + ## NOTE: invalid samples should be multiplied + ## by zero in the loss function to prevent them + ## from affecting the gradient calculation + + # when FSDP reduces the gradients over the DP dim, they're automatically averaged + # but we want to sum them so we cancel out the average here + loss *= self.dp_size * self.cp_size + loss.backward() + + if not is_teacher: + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + # end of all micro batches in the global batch + # Clean up pre-reconstructed teacher tensor after all microbatches + if teacher_logits_tensor is not None: + del teacher_logits_tensor + teacher_logits_tensor = None + if teacher_topk_indices_tensor is not None: + del teacher_topk_indices_tensor + teacher_topk_indices_tensor = None + + if is_teacher: + + with torch.no_grad(): + seq_index = data.get("seq_index", None) + + if isinstance(all_teacher_logits, DTensor): + all_teacher_logits_local = all_teacher_logits.to_local() # [B, S, V_local] + else: + all_teacher_logits_local = all_teacher_logits + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(all_teacher_logits_local.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + all_teacher_logits_local = all_teacher_logits_local.to(torch.float32) + all_teacher_logits_local_log_prob = _compute_distributed_log_softmax(all_teacher_logits_local, group=tp_group) + # Delete intermediate tensors to free memory immediately + del all_teacher_logits, all_teacher_logits_local + + if isinstance(all_teacher_logits_local_log_prob, DTensor): + all_teacher_logits_local_log_prob = 
all_teacher_logits_local_log_prob.to_local()
+
+                # If the teacher is an MDLM model, extract only the causal logits for KD
+                if self.cfg['is_mdlm']:
+                    shared_sequence_length = int(all_teacher_logits_local_log_prob.shape[1] / 2)
+                    all_teacher_logits_local_log_prob_causal = all_teacher_logits_local_log_prob[:, shared_sequence_length:, :]
+                    del all_teacher_logits_local_log_prob
+                else:
+                    all_teacher_logits_local_log_prob_causal = all_teacher_logits_local_log_prob
+
+                if topk_logits is not None:
+                    # ===== TOP-K PATH =====
+                    # Use distributed_vocab_topk on the teacher's TP-local log probs
+                    # to get the k most probable tokens and their (full-vocab-
+                    # normalized) log probabilities. Both topk_logprobs and
+                    # topk_indices are replicated across TP ranks.
+                    topk_logprobs, topk_indices = distributed_vocab_topk(
+                        all_teacher_logits_local_log_prob_causal,  # [B, S, V_local]
+                        k=topk_logits,
+                        tp_group=tp_group,
+                        vocab_start_index=vocab_start_index,
+                        vocab_end_index=vocab_end_index,
+                    )
+                    # topk_logprobs: [B, S, k] (replicated, full-vocab log probs)
+                    # topk_indices: [B, S, k] (global token ids, replicated)
+                    del all_teacher_logits_local_log_prob_causal
+
+                    B, S, K = topk_logprobs.shape
+                    target_dtype = topk_logprobs.dtype
+                    target_device = topk_logprobs.device
+
+                    # Pre-allocate two IPC buffers (values + indices) exactly once.
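As a standalone illustration of the distributed top-k used above (`distributed_vocab_topk` over a TP-sharded vocabulary), here is a single-process sketch; the `sharded_topk` name and the merge step are simplifications of what is really a local top-k per rank followed by a gather across the TP group:

```python
import heapq

# Single-process sketch of top-k over a vocab split across tensor-parallel
# shards: each shard proposes its local top-k with GLOBAL token ids, then
# the candidates are merged into the global top-k.
def sharded_topk(shards, k):
    candidates = []
    offset = 0  # vocab_start_index of the current shard
    for shard in shards:
        local = heapq.nlargest(min(k, len(shard)), enumerate(shard), key=lambda p: p[1])
        candidates.extend((v, i + offset) for i, v in local)
        offset += len(shard)
    top = heapq.nlargest(k, candidates)  # merge step (an all-gather in practice)
    return [v for v, _ in top], [i for _, i in top]

vals, idx = sharded_topk([[0.1, 0.9], [0.5, 0.7]], k=2)
```

Because each shard contributes at most `k` candidates, the merge handles `k * tp_size` entries rather than the full vocabulary, which is exactly why the top-k path is so much cheaper to ship over IPC than full log-probs.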
+ if not hasattr(self, '_teacher_topk_vals_buffer') or self._teacher_topk_vals_buffer is None: + max_S = self.cfg.get("max_total_sequence_length", S) + vals_buf_shape = (B, max_S, K) + self._teacher_topk_vals_buffer = torch.empty( + vals_buf_shape, dtype=target_dtype, device=target_device + ) + self._teacher_topk_vals_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_vals_buffer) + } + idx_buf_shape = (B, max_S, K) + self._teacher_topk_idx_buffer = torch.empty( + idx_buf_shape, dtype=topk_indices.dtype, device=target_device + ) + self._teacher_topk_idx_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_idx_buffer) + } + print(f" rank {torch.distributed.get_rank()} Allocated topk IPC buffers: " + f"vals={vals_buf_shape} ({self._teacher_topk_vals_buffer.numel() * self._teacher_topk_vals_buffer.element_size() / 1e9:.4f} GB), " + f"idx={idx_buf_shape} ({self._teacher_topk_idx_buffer.numel() * self._teacher_topk_idx_buffer.element_size() / 1e9:.4f} GB) " + f"(actual data: [{B}, {S}, {K}])") + + # Copy actual data into the top-left slice of the buffers + self._teacher_topk_vals_buffer[:B, :S, :K].copy_(topk_logprobs) + self._teacher_topk_idx_buffer[:B, :S, :K].copy_(topk_indices) + del topk_logprobs, topk_indices + + rank = torch.distributed.get_rank() + self.teacher_logits = { + rank: self._teacher_topk_vals_ipc[rank], + 'actual_shape': (B, S, K), + 'topk_indices_ipc': self._teacher_topk_idx_ipc[rank], + 'is_topk': True, + } + return self.teacher_logits + + # ===== FULL-LOGPROB PATH (no top-k) ===== + # IPC buffer for teacher logits (knowledge distillation). + # Pre-allocate at max possible S dimension on first call so the buffer + # and its IPC handle are created exactly ONCE. Each subsequent call just + # copies the actual data into a slice. + # + # Why pre-allocate? 
Every call to get_handle_from_tensor → + # reduce_tensor → _share_cuda_ creates a CudaIPCSentData entry in + # PyTorch's C++ layer that holds a smart-pointer to the underlying + # allocation. That pointer prevents cudaFree even after all Python + # references are dropped, until the cross-process reference counter + # reaches 0 AND PyTorch's IPC GC runs. Growing the buffer means + # creating a new allocation + IPC handle while the old one is still + # pinned → leaked memory. Pre-allocating avoids the problem entirely. + B, S, V = all_teacher_logits_local_log_prob_causal.shape + target_dtype = all_teacher_logits_local_log_prob_causal.dtype + target_device = all_teacher_logits_local_log_prob_causal.device + + # Initialize buffer on first call — allocate at max S from config + if not hasattr(self, '_teacher_logits_buffer') or self._teacher_logits_buffer is None: + # Use max_total_sequence_length from config to compute the + # upper-bound S so we never need to grow. + max_S = self.cfg.get("max_total_sequence_length", S) + + # For MDLM models the causal logits are the second half of + # the sequence (see slicing above), so max causal S = max_S / 2. 
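The allocate-once protocol argued for in the comment above can be sketched without CUDA or IPC; the `TeacherBuffer` class and its field names are illustrative stand-ins, not this module's API:

```python
# CPU sketch of the grow-free IPC buffer protocol: the buffer (and hence its
# handle) is created once at the maximum size; each publish copies into a
# prefix slice and ships the actual shape as metadata for the consumer.
class TeacherBuffer:
    def __init__(self, max_len):
        self.buf = [0.0] * max_len  # allocated exactly once; identity never changes

    def publish(self, data):
        assert len(data) <= len(self.buf), "data must fit in the preallocated buffer"
        self.buf[: len(data)] = data  # copy into the prefix slice
        return {"handle": self.buf, "actual_shape": len(data)}

    @staticmethod
    def consume(msg):
        # consumer slices the reconstructed tensor back to the actual data
        return msg["handle"][: msg["actual_shape"]]

buf = TeacherBuffer(max_len=8)
first = TeacherBuffer.consume(buf.publish([1.0, 2.0, 3.0]))
second = TeacherBuffer.consume(buf.publish([4.0, 5.0]))
```

Since the buffer object is never reallocated, the "handle" handed to consumers stays valid across iterations, which is the property the real CUDA IPC buffer relies on.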
+                    # if self.cfg.get('is_mdlm', False):
+                    #     max_S = max_S // 2
+
+                    # B and V are fixed across iterations
+                    buf_shape = (B, max_S, V)
+                    self._teacher_logits_buffer = torch.empty(
+                        buf_shape, dtype=target_dtype, device=target_device
+                    )
+                    self._teacher_logits_ipc_handle = {
+                        torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_logits_buffer)
+                    }
+                    print(f" rank {torch.distributed.get_rank()} Allocated IPC buffer: shape={buf_shape}, "
+                          f"size={self._teacher_logits_buffer.numel() * self._teacher_logits_buffer.element_size() / 1e9:.4f} GB "
+                          f"(actual data: [{B}, {S}, {V}])")
+
+                # Copy actual data into the top-left slice of the buffer
+                self._teacher_logits_buffer[:B, :S, :V].copy_(all_teacher_logits_local_log_prob_causal)
+                del all_teacher_logits_local_log_prob_causal
+
+                # Return the IPC handle dict with actual shape metadata.
+                # Consumer must slice the reconstructed tensor to [:B, :S, :V].
+                self.teacher_logits = {
+                    torch.distributed.get_rank(): self._teacher_logits_ipc_handle[torch.distributed.get_rank()],
+                    'actual_shape': (B, S, V),
+                }
+
+                return self.teacher_logits
+
+        grad_norm: Optional[float | torch.Tensor] = None
+        if not eval_mode:
+            with torch.no_grad():
+                grad_norm = get_grad_norm(
+                    self.model.parameters(),
+                    dp_cp_group=self.dp_cp_mesh.get_group(),
+                    tp_group=self.tp_mesh.get_group(),
+                    dtype=torch.float32,
+                )
+                if self.max_grad_norm is not None:
+                    clip_grad_by_total_norm_(
+                        self.model.parameters(),
+                        max_grad_norm=self.max_grad_norm,
+                        total_norm=grad_norm,
+                    )
+                grad_norm = torch.tensor([grad_norm])
+
+        # Update parameters
+        self.optimizer.step()
+
+        losses.append(torch.tensor(mb_losses).sum().item())
+
+        # release gradient memory before rollouts
+        self.optimizer.zero_grad()
+        # increment scheduler after all batches in rollout are processed
+        if not eval_mode:
+            self.scheduler.step()
+        # dynamic batch and sequence dims cause a lot of fragmentation, so clear
+        # the memory allocator before moving on
+        torch.cuda.empty_cache()
+
+        # Compute global loss across all ranks
+        with torch.no_grad():
+            global_loss = torch.tensor(losses, device="cuda")
+            torch.distributed.all_reduce(
+                global_loss, group=self.dp_mesh.get_group()
+            )
+        # Aggregate metrics across all microbatches
+        mb_metrics = defaultdict(list)
+        for m in all_mb_metrics:
+            for k, v in m.items():
+                mb_metrics[k].append(v)
+
+        metrics = {
+            "global_loss": global_loss.cpu(),
+            "grad_norm": grad_norm,
+            "rank": torch.distributed.get_rank(),
+            "gpu_name": torch.cuda.get_device_name(),
+            "model_dtype": self.dtype,
+            "all_mb_metrics": dict(mb_metrics),
+        }
+
+        return metrics
+
+    # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094)
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_logprobs")
+    def get_logprobs(
+        self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
+    ) -> BatchedDataDict[LogprobOutputSpec]:
+        """Get the logprobs of the model for a batch of data.
+
+        Uses the configured logprob_batch_size to do microbatching.
+
+        Input data is assumed to be right-padded. The method internally converts to
+        left-padded format for computation, and returns outputs in right-padded format.
+
+        Returns:
+            a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length].
+            We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
+            The logprob of input token i is specified at position i in the output logprobs tensor.
+ """ + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) + + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + + all_log_probs = [] + self.model.eval() + + with unshard_fsdp2_model(self.model), torch.no_grad(): + data.to("cuda") + dummy_iterator = iter([]) + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(logprob_batch_size) + iterator_len = data.size // logprob_batch_size + + step = 0 + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + vlm_kwargs = lp_batch.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + + batch_size, seq_len = input_ids.shape + if self.enable_seq_packing: + assert len(vlm_kwargs) == 0, ( + "multimodal kwargs are not supported for sequence packing" + ) + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Create post_attention_mask for right-padded data for masking token after forwarding. 
+                post_attention_mask = torch.zeros(
+                    (batch_size, seq_len), dtype=torch.bool, device=input_ids.device
+                )
+                for i, length in enumerate(input_lengths):
+                    # For right-padded sequences, set 1s at the beginning of the sequence
+                    post_attention_mask[i, :length] = 1
+
+                # explicitly create position ids for the input, otherwise the sharding
+                # for DTensor will be incorrect
+                position_ids = torch.arange(
+                    seq_len, device=input_ids.device
+                ).repeat(batch_size, 1)
+                flash_attn_kwargs = {}
+
+                # DTensor requires the causal attention kernel to hit, yet the
+                # attention mask above is not always all 1s. This is fine because
+                # we mask with the actual attention mask later, but the mask fed
+                # to the model has to be all 1s.
+                attention_mask = torch.ones(
+                    (batch_size, seq_len),
+                    dtype=torch.bool,
+                    device=input_ids.device,
+                )
+
+            # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
+            if len(vlm_kwargs) > 0:
+                position_ids = None
+
+            context_parallel_ctx = None
+            if self.cp_size > 1:
+                assert len(vlm_kwargs) == 0, (
+                    "multimodal kwargs are not supported for context parallel"
+                )
+                seq_index = torch.arange(seq_len, device=input_ids.device).repeat(
+                    1, 1
+                )
+                cp_buffers = [input_ids, position_ids, seq_index]
+
+                # Create context parallel context
+                context_parallel_ctx = self.create_context_parallel_ctx(
+                    cp_mesh=self.cp_mesh,
+                    cp_buffers=cp_buffers,
+                    cp_seq_dims=[sequence_dim] * len(cp_buffers),
+                    cp_no_restore_buffers=set(cp_buffers),
+                )
+
+            with DTensorPolicyWorker.train_context(context_parallel_ctx):
+                with torch.autocast(device_type="cuda", dtype=self.dtype):
+                    model_args = dict(
+                        input_ids=input_ids,
+                        attention_mask=attention_mask,
+                        position_ids=position_ids,
+                        use_cache=False,
+                        flash_attn_kwargs=flash_attn_kwargs,
+                        **vlm_kwargs,
+                    )
+                    if len(vlm_kwargs) > 0:
+                        del model_args["flash_attn_kwargs"]
+
+                    outputs = self.model(**model_args)
+
+                logits = outputs.logits
+
+                # Apply temperature scaling
+                logits = 
self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + seq_index_tensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + input_ids_dtensor = DTensor.from_local( + input_ids, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits, + input_ids_dtensor, + seq_index_tensor, + chunk_size=logprob_chunk_size, + ) + + assert token_logprobs.shape[1] == seq_len - 1 + else: + if isinstance(logits, DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits, + input_ids, + chunk_size=logprob_chunk_size, + ) + else: + if logprob_chunk_size is not None: + logits_seq_len = int(logits.shape[1]) + num_chunks = ( + logits_seq_len + logprob_chunk_size - 1 + ) // logprob_chunk_size + chunked_log_probs = [] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * logprob_chunk_size + chunk_end = min( + logits_seq_len, + (chunk_idx + 1) * logprob_chunk_size, + ) + chunk_logits = logits[ + :, chunk_start:chunk_end, : + ].to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + chunk_logits, dim=-1 + ) + chunked_log_probs.append(log_probs) + log_probs = torch.cat(chunked_log_probs, dim=1) + del chunked_log_probs + else: + logits = logits.to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + logits, 
dim=-1
+                            )
+                    # Extract per-token logprobs by gathering, at each position, the
+                    # logprob of the next token.
+                    # Input shapes:
+                    #   log_probs: [batch_size, sequence_length, vocab_size] - log probs at each position
+                    #   input_ids: [batch_size, sequence_length] - actual tokens
+                    # Output shape: [batch_size, sequence_length] - logprob of each token given the previous ones.
+                    # We get the logprob of token[t+1] from log_probs[t], prepending 0 to maintain sequence length.
+                    next_tokens = input_ids[:, 1:]
+                    log_probs = log_probs[:, :-1]
+                    token_logprobs = log_probs.gather(
+                        dim=-1, index=next_tokens.unsqueeze(-1)
+                    ).squeeze(-1)
+                    del log_probs
+
+                del outputs, logits
+
+                token_logprobs = torch.cat(
+                    [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
+                )
+
+                # skip keeping the logprobs for the dummy batches
+                if batch_idx >= iterator_len:
+                    continue
+
+                if not self.enable_seq_packing:
+                    # Apply mask to zero out padding tokens' logprobs
+                    token_logprobs = token_logprobs * post_attention_mask
+                else:
+                    # For packed sequences, unpack logprobs
+                    unpacked_logprobs = torch.zeros(
+                        (batch_size, seq_dim_size),
+                        dtype=token_logprobs.dtype,
+                        device=token_logprobs.device,
+                    )
+                    cu_seqlens = flash_attn_kwargs.cu_seqlens_q
+                    for i in range(batch_size):
+                        start = cu_seqlens[i].item() + 1
+                        end = cu_seqlens[i + 1].item()
+                        seq_len_actual = input_lengths[i].item()
+                        unpacked_logprobs[i, 1:seq_len_actual] = token_logprobs[
+                            0, start:end
+                        ]
+                    token_logprobs = unpacked_logprobs
+
+                all_log_probs.append(token_logprobs)
+
+        # Concatenate all batches
+        return_data = BatchedDataDict[LogprobOutputSpec]()
+
+        all_log_probs_padded = []
+        for lp in all_log_probs:
+            padding_needed = seq_dim_size - lp.shape[1]
+            if padding_needed > 0:
+                lp = torch.nn.functional.pad(
+                    lp, (0, padding_needed), mode="constant", value=0.0
+                )
+            all_log_probs_padded.append(lp)
+        return_data["logprobs"] = torch.cat(all_log_probs_padded, dim=0).cpu()
+
+        return return_data
+
+    # TODO @Rayen 
Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094) + @wrap_with_nvtx_name("dtensor_policy_worker/score") + def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]: + global_batch_size = min(self.cfg["batch_size"], data.size) + + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + self.model.eval() + + with unshard_fsdp2_model(self.model), torch.no_grad(): + data.to("cuda") + dummy_iterator = iter([]) + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(global_batch_size) + iterator_len = data.size // global_batch_size + + step = 0 + all_rm_scores = [] + for batch_idx, generate_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 + input_ids = generate_batch.get("input_ids").cuda() + input_lengths = generate_batch.get("input_lengths") + batch_size, seq_len = input_ids.shape + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + 
input_ids=input_ids,
+                        input_lengths=input_lengths,
+                        packed_sequence_size=[
+                            batch_size
+                        ],  # flash attention 2 expects flattened input
+                        padding_value=self.tokenizer.eos_token_id,
+                        return_attention_mask=False,
+                    )
+                    seq_len = input_ids.shape[1]
+                    attention_mask = None
+                    flash_attn_kwargs = get_flash_attention_kwargs(
+                        input_lengths=input_lengths,
+                    )
+                else:
+                    # Create attention mask for right-padded data
+                    post_attention_mask = torch.zeros(
+                        (batch_size, seq_len), dtype=torch.bool, device=input_ids.device
+                    )
+                    for i, length in enumerate(input_lengths):
+                        # For right-padded sequences, set 1s at the beginning of the sequence
+                        post_attention_mask[i, :length] = 1
+                    position_ids = torch.arange(
+                        seq_len, device=input_ids.device
+                    ).repeat(batch_size, 1)
+
+                    attention_mask = torch.ones(
+                        (batch_size, seq_len),
+                        dtype=torch.bool,
+                        device=input_ids.device,
+                    )
+
+                context_parallel_ctx = None
+                if self.cp_size > 1:
+                    seq_index = torch.arange(seq_len, device=input_ids.device).repeat(
+                        1, 1
+                    )
+                    cp_buffers = [input_ids, position_ids, seq_index]
+
+                    # Create context parallel context
+                    context_parallel_ctx = self.create_context_parallel_ctx(
+                        cp_mesh=self.cp_mesh,
+                        cp_buffers=cp_buffers,
+                        cp_seq_dims=[sequence_dim] * len(cp_buffers),
+                        cp_no_restore_buffers=set(cp_buffers),
+                    )
+                with DTensorPolicyWorker.train_context(context_parallel_ctx):
+                    with torch.autocast(device_type="cuda", dtype=self.dtype):
+                        model_args = dict(
+                            input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            position_ids=position_ids,
+                            use_cache=False,
+                        )
+                        outputs = self.model(**model_args)
+
+                    if not hasattr(outputs, "logits"):
+                        logits = self.model.lm_head(outputs.last_hidden_state)
+                    else:
+                        logits = outputs.logits
+                    # Apply temperature scaling
+                    logits = self._apply_temperature_scaling(logits)
+                    # Cast to float32 after temperature scaling. Use `logits`, not
+                    # `outputs.logits`, so the lm_head fallback and the scaling
+                    # above are preserved for both DTensor and plain tensors.
+                    logits = logits.to(torch.float32)
+
+                    rm_scores = to_local_if_dtensor(logits)
+                    rm_scores = 
rm_scores.squeeze(-1)
+                all_rm_scores.append(rm_scores)
+
+            all_rm_scores = torch.cat(all_rm_scores, dim=0)
+            all_rm_scores = all_rm_scores.squeeze(-1).cpu()
+            return_data = BatchedDataDict[ScoreOutputSpec](
+                {
+                    "scores": all_rm_scores,
+                }
+            )
+            return return_data
+
+    @wrap_with_nvtx_name("dtensor_policy_worker/get_topk_logits")
+    def get_topk_logits(
+        self,
+        data: BatchedDataDict[Any],
+        k: int,
+        micro_batch_size: Optional[int] = None,
+    ) -> BatchedDataDict[Any]:
+        """Return per-position top-k logits and corresponding global indices.
+
+        Notes:
+        - Return shapes are [B, S, k].
+        - Computes top-k over the full sequence (no trimming of the last position).
+        - If alignment with next-token targets is required, the caller should handle it.
+        - If logits are TP-sharded DTensor, performs distributed global top-k across TP.
+        - Supports context parallelism with proper CP gather.
+        - Otherwise, computes local top-k on full-vocab tensor.
+        """
+        topk_batch_size = (
+            micro_batch_size
+            if micro_batch_size is not None
+            else self.cfg["logprob_batch_size"]
+        )
+
+        sequence_dim = 1
+        seq_dim_size = data.get("input_ids").shape[sequence_dim]
+
+        out_topk_vals = []
+        out_topk_idx = []
+        self.model.eval()
+
+        with torch.no_grad():
+            data.to("cuda")
+            dummy_iterator = iter([])
+            if self.cfg["dynamic_batching"]["enabled"]:
+                # dynamic batching support (no CP/packed)
+                mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes()
+                iterator_len = data.get_microbatch_iterator_dynamic_shapes_len()
+            elif self.enable_seq_packing:
+                mb_iterator = data.make_microbatch_iterator_for_packable_sequences()
+                iterator_len, max_seqlen = (
+                    data.get_microbatch_iterator_for_packable_sequences_len()
+                )
+                max_batch_ct = torch.tensor([iterator_len], device="cuda")
+                torch.distributed.all_reduce(
+                    max_batch_ct, op=torch.distributed.ReduceOp.MAX
+                )
+
+                # Sequence packing can end up with unevenly distributed batch counts across DP ranks.
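The dummy-batch trick described in this comment can be sketched in isolation. A minimal single-process illustration (names here are invented for illustration; a plain `max()` stands in for the `torch.distributed.all_reduce(MAX)` above):

```python
import itertools

def pad_with_dummy_batches(local_batches, all_rank_counts):
    """Pad one DP rank's micro-batch list so every rank runs the same step count."""
    # Stand-in for all_reduce(MAX): every rank learns the global maximum count.
    max_ct = max(all_rank_counts)
    dummy_ct = max_ct - len(local_batches)
    # Recycle this rank's own batches as dummies; their outputs are discarded,
    # but the extra forward passes keep collective ops in sync across ranks.
    dummies = itertools.islice(itertools.cycle(local_batches), dummy_ct)
    return list(itertools.chain(local_batches, dummies))

# This rank packed 2 micro-batches, but the busiest rank packed 4:
padded = pad_with_dummy_batches(["mb0", "mb1"], all_rank_counts=[2, 4, 3])
# padded == ["mb0", "mb1", "mb0", "mb1"]
```

Every rank pads up to the global maximum with recycled local batches, so all ranks issue the same number of forward passes and no NCCL collective is left waiting.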
+ # We add dummy batches to the end of the iterator to make the batch counts equal. + dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(topk_batch_size) + iterator_len = data.size // topk_batch_size + + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + vlm_kwargs = lp_batch.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + batch_size, seq_len = input_ids.shape + + # Store original shapes for unpacking later + original_batch_size = batch_size + original_seq_len = seq_len + + if self.enable_seq_packing: + assert len(vlm_kwargs) == 0, ( + "multimodal kwargs are not supported for sequence packing" + ) + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Build attention mask (right-padded inputs) + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + flash_attn_kwargs = {} + + with torch.autocast(device_type="cuda", dtype=self.dtype): + attention_mask_input_all_ones = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + + # if there are multimodal kwargs, we don't need to add 
position_ids (computed internally) + if len(vlm_kwargs) > 0: + position_ids = None + + context_parallel_ctx = None + if self.cp_size > 1: + assert len(vlm_kwargs) == 0, ( + "multimodal kwargs are not supported for context parallel" + ) + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = self.create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with DTensorPolicyWorker.train_context(context_parallel_ctx): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + if len(vlm_kwargs) > 0: + del model_args["flash_attn_kwargs"] + + outputs = self.model(**model_args) + + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + # deal with TP first + local_logits = logits.to_local() # [B, S_cp, V_tp] + + tp_group = self.tp_mesh.get_group() + tp_rank = 
torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + vals, idx = distributed_vocab_topk( + local_logits, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + # [B, S_cp, k] + + cp_group = self.cp_mesh.get_group() + + vals = allgather_cp_sharded_tensor( + vals, cp_group, seq_dim=sequence_dim + ) + idx = allgather_cp_sharded_tensor( + idx, cp_group, seq_dim=sequence_dim + ) + # [B, S, k] + else: + # Compute top-k over full sequence length (do not drop last position) + if isinstance(logits, DTensor): + local_logits = logits.to_local() # [B, S, V_local] + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + vals, idx = distributed_vocab_topk( + local_logits, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + else: + full_logits = logits.to(torch.float32) + vals, idx = torch.topk(full_logits, k=k, dim=-1) + + # Handle sequence packing unpacking + if self.enable_seq_packing: + # Unpack top-k results from packed format back to original batch format + # vals: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + # idx: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + + # Create tensors to store unpacked results + unpacked_vals = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=vals.dtype, + device=vals.device, + ) + unpacked_idx = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=idx.dtype, + device=idx.device, + ) + + # Get cumulative sequence lengths for unpacking + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + + for i in range(original_batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + 
seq_len_actual = input_lengths[i].item() + + # Extract the corresponding portion from packed results + # Note: vals and idx are [1, packed_seq_len, k] due to packing + unpacked_vals[i, :seq_len_actual, :] = vals[0, start:end, :] + unpacked_idx[i, :seq_len_actual, :] = idx[0, start:end, :] + + # Replace with unpacked results + vals = unpacked_vals + idx = unpacked_idx + + # Update batch_size and seq_len for consistency + batch_size = original_batch_size + seq_len = original_seq_len + + # Keep only real sequence tokens (no trimming here; padded positions can be masked downstream) + # Shapes remain [B, S, k]. + out_topk_vals.append(vals.cpu()) + out_topk_idx.append(idx.cpu()) + + ret = BatchedDataDict[Any]() + # Pad each micro-batch result on sequence dim to common length (S), similar to get_logprobs + all_topk_vals_padded = [] + all_topk_idx_padded = [] + target_seq_len = seq_dim_size + for vals, idx in zip(out_topk_vals, out_topk_idx): + pad_needed = target_seq_len - vals.shape[1] + if pad_needed > 0: + # pad along sequence dimension (second dim): (last_dim_pad_left, last_dim_pad_right, seq_pad_left, seq_pad_right, batch_pad_left, batch_pad_right) + vals = torch.nn.functional.pad( + vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 + ) + idx = torch.nn.functional.pad( + idx, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0 + ) + all_topk_vals_padded.append(vals) + all_topk_idx_padded.append(idx) + + ret["topk_logits"] = ( + torch.cat(all_topk_vals_padded, dim=0) + if len(all_topk_vals_padded) > 1 + else all_topk_vals_padded[0] + ).cpu() + ret["topk_indices"] = ( + torch.cat(all_topk_idx_padded, dim=0) + if len(all_topk_idx_padded) > 1 + else all_topk_idx_padded[0] + ).cpu() + return ret + + @contextmanager + def use_reference_model(self) -> Generator[None, None, None]: + """Context manager that temporarily swaps the reference model and active model. + + On entry: Moves model to CPU, moves reference_model to CUDA. 
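Stepping back from the worker code: the distributed top-k used in `get_topk_logits` above boils down to offsetting each TP rank's local indices by its `vocab_start_index` and merging candidates. A minimal single-process sketch of that merge (invented names; the real `distributed_vocab_topk` gathers candidates over the TP process group):

```python
import heapq

def merge_sharded_topk(shard_vals, k, shard_size):
    # shard_vals[r] is rank r's logits over its vocabulary slice
    # [r * shard_size, (r + 1) * shard_size), i.e. vocab_start_index = r * shard_size.
    candidates = []
    for rank, vals in enumerate(shard_vals):
        offset = rank * shard_size
        # Local top-k per shard; keep (value, global_index) pairs.
        local = heapq.nlargest(k, enumerate(vals), key=lambda pair: pair[1])
        candidates.extend((v, offset + i) for i, v in local)
    # A global top-k over the k * num_shards candidates equals a top-k over the full vocab.
    top = heapq.nlargest(k, candidates)
    return [v for v, _ in top], [i for _, i in top]

vals, idx = merge_sharded_topk([[0.1, 3.0], [2.0, 0.5]], k=2, shard_size=2)
# vals == [3.0, 2.0]; idx == [1, 2] (global vocabulary indices)
```

The same offset-then-merge shape is what lets the worker return global vocabulary indices even though each TP rank only ever sees its own vocab slice.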
Swaps the references + On exit: Restores original references and re-flips cuda/cpu + """ + with torch.no_grad(): + try: + # Save train model state_dict + curr_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) + + # Swap reference model state_dict to self.model + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(self.reference_model_state_dict[k]) + + # - self.model is the original reference_model, now on CUDA + # - curr_state_dict is the train model, now on CPU + yield + + finally: + # Restore train model state_dict + for k, v in self.model.state_dict().items(): + val = to_local_if_dtensor(v) + val.copy_(curr_state_dict[k]) + + @wrap_with_nvtx_name("dtensor_policy_worker/get_reference_policy_logprobs") + def get_reference_policy_logprobs( + self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None + ) -> BatchedDataDict[ReferenceLogprobOutputSpec]: + """Get the logprobs from the reference policy for a batch of data. + + Returns: + a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + with self.use_reference_model(): + reference_logprobs = self.get_logprobs(data, micro_batch_size) + + return_data = BatchedDataDict[ReferenceLogprobOutputSpec]() + return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() + return return_data + + def _add_noise_to_weights(self) -> None: + """Add small Gaussian noise to the weights of the model. 
Note that this is used for testing purposes only.""" + noise_std = 0.01 # Standard deviation for the noise + for p in self.model.parameters(): + if p.requires_grad: + noise = torch.randn_like(p.data) * noise_std + p.data.add_(noise) # Add noise in-place + torch.cuda.synchronize() + + def return_state_dict(self): + return self.model.state_dict() + + def return_model_config(self) -> dict[str, Any]: + """Return the model configuration as a dictionary. + + Returns: + dict: Model configuration dictionary + """ + return self.model.config + + def report_device_id(self) -> str: + """Report the UUID of the current CUDA device using NVML. + + Returns: + str: UUID of the device in the format "GPU-xxxxx" + """ + from nemo_rl.utils.nvml import get_device_uuid + + # Get current device index from torch + device_idx = torch.cuda.current_device() + # Get device UUID using NVML + return get_device_uuid(device_idx) + + def get_zmq_address(self): + """Get the ZMQ address for the current device.""" + return f"ipc:///tmp/{self.report_device_id()}.sock" + + def maybe_init_zmq(self): + """Initialize the ZMQ socket if it doesn't exist.""" + if not hasattr(self, "zmq_socket"): + self.zmq_context = zmq.Context() + self.zmq_socket = self.zmq_context.socket(zmq.REQ) + self.zmq_socket.setsockopt( + zmq.SNDTIMEO, 120000 + ) # set timeout to 120 seconds + self.zmq_socket.setsockopt( + zmq.RCVTIMEO, 120000 + ) # set timeout to 120 seconds + self.zmq_socket.setsockopt(zmq.LINGER, 0) + self.zmq_socket.bind(self.get_zmq_address()) + + @torch.no_grad() + def prepare_refit_info(self) -> Optional[dict[str, Any]]: + """Prepare state dict metadata for weight refitting and IPC streaming.""" + state_dict_info = {} + for name, tensor in self.model.state_dict().items(): + # all tensor will be casted to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective + state_dict_info[name] = (tensor.shape, self.dtype) + + return state_dict_info + + def get_free_memory_bytes(self) -> int: + """Get the 
available free memory."""
+        from nemo_rl.utils.nvml import get_free_memory_bytes
+
+        device_idx = torch.cuda.current_device()
+        return get_free_memory_bytes(device_idx)
+
+    @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_ipc_zmq")
+    def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None:
+        """Stream model weights to peer process via ZMQ IPC socket."""
+        self.maybe_init_zmq()
+        # Manually move model to cuda for cpu offload case
+        if self.cpu_offload:
+            self.model = self.move_to_cuda(self.model)
+
+        from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl
+
+        def dtensor_params_generator():
+            """Generator that yields (name, tensor) pairs, converting DTensors to local tensors."""
+            for name, tensor in self.model.state_dict().items():
+                if isinstance(tensor, DTensor):
+                    # Convert DTensor to full tensor for streaming
+                    full_tensor = tensor.full_tensor()
+                    # Convert to target dtype
+                    yield (
+                        name,
+                        full_tensor.to(self.dtype, non_blocking=True).contiguous(),
+                    )
+                else:
+                    # Convert to target dtype
+                    yield name, tensor.to(self.dtype, non_blocking=True).contiguous()
+
+        # Use the shared implementation
+        stream_weights_via_ipc_zmq_impl(
+            params_generator=dtensor_params_generator(),
+            buffer_size_bytes=buffer_size_bytes,
+            zmq_socket=self.zmq_socket,
+            rank=self.rank,
+            worker_name=str(self),
+        )
+
+    @torch.no_grad()
+    def broadcast_weights_for_collective(self) -> None:
+        """Broadcast the weights for collective communication."""
+        # Manually move model to cuda for cpu offload case
+        if self.cpu_offload:
+            print(
+                "[WARNING]: Unless you are short on memory, it is not recommended to enable cpu_offload when "
+                "using non-colocated generation since it will have an extra onload and offload at refit stage."
+ ) + self.model = self.move_to_cuda(self.model) + + def _dtensor_post_iter_func(tensor, dtype): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(dtype, non_blocking=True) + return tensor + + # param_iterator will return (name, tensor), we only need tensor + dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) + + packed_broadcast_producer( + iterator=iter(self.model.state_dict().items()), + group=self.model_update_group, + src=0, + post_iter_func=dtensor_post_iter_func, + ) + + # Manually move model to cpu for cpu offload case + # cpu offload needs model on CPU before model forward + if self.cpu_offload: + self.model = self.move_to_cpu(self.model) + + @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference") + def prepare_for_lp_inference(self) -> None: + # onload model to cuda + if not self.cpu_offload: + self.move_to_cuda(self.model) + else: + self.model = self.move_buffer_to_device(self.model, "cuda") + + self.model.eval() + + # offload optimizer to cpu + torch.randn(1).cuda() # wake up torch allocator + if self.optimizer is not None and self.offload_optimizer_for_logprob: + self.move_optimizer_to_device("cpu") + + gc.collect() + torch.cuda.empty_cache() + + @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training") + def prepare_for_training(self, *args, **kwargs) -> None: + # onload models and optimizer state to cuda + if not self.cpu_offload: + self.move_to_cuda(self.model) + else: + # when cpu offload is enabled, the buffers do not get moved + # to cuda automatically, so we need to do that manually + self.model = self.move_buffer_to_device(self.model, "cuda") + + self.model.train() + # Move optimizer state to CUDA if it exists + # colocated generation will always offload optimizer to cuda before refit + if ( + self.optimizer is not None + and not self.cpu_offload + and (self.offload_optimizer_for_logprob or self.is_generation_colocated) + ): + 
self.move_optimizer_to_device("cuda") + + torch.cuda.empty_cache() + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker/offload_before_refit") + def offload_before_refit(self) -> None: + """Offload the optimizer to the CPU.""" + torch.randn(1).cuda() # wake up torch allocator + if self.optimizer is not None: + self.move_optimizer_to_device("cpu") + + gc.collect() + torch.cuda.empty_cache() + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker/offload_after_refit") + def offload_after_refit(self) -> None: + """Offload as much as possible on the CPU.""" + self.model = self.move_to_cpu(self.model) + self.model.eval() + torch.randn(1).cuda() # wake up torch allocator + self.offload_before_refit() # rerun the old offload function + + # Print memory stats after offloading + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + def move_optimizer_to_device(self, device: str | torch.device) -> None: + for state in self.optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, (DTensor, torch.Tensor)): + state[k] = v.to(device) + + def move_to_device(self, model: nn.Module, device: str | torch.device) -> nn.Module: + model = self.move_buffer_to_device(model, device) + return model.to(device) + + def move_buffer_to_device( + self, model: nn.Module, device: str | torch.device + ) -> nn.Module: + # FSDP modules do not move buffers to the device automatically + for v in model.buffers(): + v.data = v.data.to(device) + + return model + + def move_to_cuda(self, model: torch.nn.Module) -> torch.nn.Module: + model = self.move_to_device(model, "cuda") + gc.collect() + torch.cuda.empty_cache() + return model + + def move_to_cpu(self, model: torch.nn.Module) -> torch.nn.Module: + model = self.move_to_device(model, "cpu") + gc.collect() + 
torch.cuda.empty_cache()
+        return model
+
+    def save_checkpoint(
+        self,
+        weights_path: str,
+        optimizer_path: Optional[str] = None,
+        tokenizer_path: Optional[str] = None,
+    ) -> None:
+        """Save a checkpoint of the model.
+
+        The optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
+        """
+        save_checkpoint(
+            model=self.model,
+            weights_path=weights_path,
+            optimizer=self.optimizer if optimizer_path else None,
+            scheduler=self.scheduler if optimizer_path else None,
+            optimizer_path=optimizer_path,
+            tokenizer=self.tokenizer if tokenizer_path else None,
+            tokenizer_path=tokenizer_path,
+        )
+
+    def load_checkpoint(
+        self, weights_path: str, optimizer_path: Optional[str] = None
+    ) -> None:
+        """Load a checkpoint into the model."""
+        load_checkpoint(
+            model=self.model,
+            weights_path=weights_path,
+            optimizer=self.optimizer if optimizer_path else None,
+            scheduler=self.scheduler if optimizer_path else None,
+            optimizer_path=optimizer_path,
+        )
+
+    def shutdown(self) -> None:
+        """Shutdown the policy."""
+        # Clean up extension resources like ZMQ sockets
+        if hasattr(self, "zmq_socket"):
+            self.zmq_socket.close()
+            self.zmq_context.term()
+
+    def start_gpu_profiling(self) -> None:
+        """Start GPU profiling."""
+        torch.cuda.profiler.start()
+
+    def stop_gpu_profiling(self) -> None:
+        """Stop GPU profiling."""
+        torch.cuda.profiler.stop()
+
+    def get_model_config(self):
+        return self.model.config.to_dict()
+
+    def report_node_ip_and_gpu_id(self) -> tuple[str, int]:
+        """Report the node IP and GPU ID of the current worker."""
+        ip = ray._private.services.get_node_ip_address()
+        gpu_id = ray.get_gpu_ids()[0]
+        return (ip, gpu_id)
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index 8df5e1f15c..b3bcc86a85 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -245,6 +245,7 @@ def
train( eval_mode: bool = False, gbs: Optional[int] = None, mbs: Optional[int] = None, + **kwargs, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" # Note: zero_grad_buffer is called at the start of each global batch iteration diff --git a/nemo_rl/models/policy/workers/not_working.py b/nemo_rl/models/policy/workers/not_working.py new file mode 100644 index 0000000000..d54907ba9b --- /dev/null +++ b/nemo_rl/models/policy/workers/not_working.py @@ -0,0 +1,2289 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import itertools +import os +import warnings +from collections import defaultdict +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any, Generator, Optional, cast + +import ray +import torch +from accelerate import init_empty_weights +from hydra.utils import get_class +from nemo_automodel import ( + NeMoAutoModelForSequenceClassification, +) +from nemo_automodel._transformers.registry import ModelRegistry +from nemo_automodel.components._peft.lora import ( + PeftConfig, + apply_lora_to_linear_modules, +) +from nemo_automodel.components.config.loader import _resolve_target +from nemo_automodel.components.distributed.cp_utils import ( + create_context_parallel_ctx, + get_train_context, +) +from nemo_automodel.components.distributed.fsdp2 import ( + FSDP2Manager, +) +from nemo_automodel.components.distributed.tensor_utils import ( + get_cpu_state_dict, + to_local_if_dtensor, +) +from nemo_automodel.components.moe.parallelizer import ( + parallelize_model as moe_parallelize_model, +) +from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm +from torch import nn +from torch.distributed.fsdp import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, +) +from torch.distributed.tensor import DTensor, Shard +from transformers import ( + AutoConfig, + AutoProcessor, + AutoTokenizer, + PreTrainedModel, +) +from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM + +from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + _compute_distributed_log_softmax, + allgather_cp_sharded_tensor, + distributed_vocab_topk, + get_logprobs_from_vocab_parallel_logits, +) +from nemo_rl.models.policy.utils import get_handle_from_tensor, rebuild_cuda_tensor_from_ipc +from 
nemo_rl.models.huggingface.common import ( + get_flash_attention_kwargs, + pack_sequences, +) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.interfaces import ( + ColocatablePolicyInterface, + LogprobOutputSpec, + ScoreOutputSpec, +) +from nemo_rl.models.policy.utils import ( + configure_dynamo_cache, + get_runtime_env_for_policy_worker, + resolve_model_class, +) +from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker +from nemo_rl.models.policy.workers.patches import ( + apply_torch_aten_alias_tensor_patch, + apply_transformer_engine_patch, +) +from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager +from nemo_rl.utils.checkpoint import CheckpointingConfig +from nemo_rl.utils.nsys import wrap_with_nvtx_name +from nemo_rl.utils.packed_tensor import packed_broadcast_producer + +STRING_TO_DTYPE = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + + +@ray.remote( + runtime_env=get_runtime_env_for_policy_worker("dtensor_distillation_worker") +) # pragma: no cover +class DTensorDistillationWorker(AbstractPolicyWorker, ColocatablePolicyInterface): + """DTensor worker that holds both student and teacher models for distillation. + + The teacher model runs forward-only (no grad) and stores top-k logprobs in + GPU IPC buffers. The student model trains using those logprobs without any + data leaving the GPU (no Ray object store transfer). 
+    """
+
+    def __repr__(self) -> str:
+        if torch.distributed.is_initialized():
+            return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]"
+        else:
+            return f"{self.__class__.__qualname__}"
+
+    def __init__(
+        self,
+        config: PolicyConfig,
+        tokenizer: AutoTokenizer,
+        processor: Optional[AutoProcessor] = None,
+        weights_path: Optional[str] = None,
+        optimizer_path: Optional[str] = None,
+        init_optimizer: bool = True,
+        init_reference_model: bool = True,
+        teacher_config: Optional[PolicyConfig] = None,
+        **kwargs: Any,
+    ):
+        """Initialize the DTensorDistillationWorker with student and (optionally) teacher models."""
+        # Apply TE patch until TE is upgraded to 2.10.0
+        apply_transformer_engine_patch()
+        # Apply patch to work around 'NotImplementedError: Operator aten.alias.default does not have a sharding strategy registered'
+        apply_torch_aten_alias_tensor_patch()
+
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.is_vlm = processor is not None
+
+        print(f"Initializing DTensorDistillationWorker with is_vlm={self.is_vlm}")
+
+        self.is_generation_colocated = None
+        if "generation" in config and config["generation"] is not None:
+            self.is_generation_colocated = config["generation"]["colocated"]["enabled"]
+
+        # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator.
+        # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
+        if not self.is_generation_colocated:
+            os.environ["NCCL_CUMEM_ENABLE"] = "1"
+
+        # Disable dynamo autotune_local_cache to avoid crash when there's already a cache
+        # with different order of node_bundles
+        configure_dynamo_cache()
+
+        self.cfg = config
+        self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"]
+        # torch distributed init.
Envars for rank, world_size, and master_addr and master_port are set from the ray remote call + backend = "nccl" if not self.cpu_offload else "cuda:nccl,cpu:gloo" + torch.distributed.init_process_group(backend=backend) + self.rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + model_name = self.cfg["model_name"] + + self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] + self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] + self.max_grad_norm = self.cfg["max_grad_norm"] + + try: + self.dtype = STRING_TO_DTYPE[self.cfg["precision"]] + except KeyError: + raise ValueError(f"Unknown precision: {self.cfg['precision']}") + + self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] + if self.enable_seq_packing: + assert not self.is_vlm, ( + "Sequence packing is not supported for VLM models. Please set policy.sequence_packing.enabled = False to train VLM models." 
+ ) + print( + f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}" + ) + print(f"[Rank {self.rank}] Using FlashAttention2 for sequence packing") + + hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} + + # Choose attention implementation on the following basis: + # - Packed sequence requires FA2 and CP must be 1 + # - CP > 1 requires SDPA + cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"] + + # NeMoAutoModelForCausalLM uses flash_attention_2 by default + # so we need to set it to None if sequence packing is disabled + # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 + attn_impl = ( + "flash_attention_2" + if (self.enable_seq_packing and cp_size_cfg == 1) + else ("sdpa" if cp_size_cfg > 1 else None) + ) + + model_config = AutoConfig.from_pretrained( + model_name, + # Always load the model in float32 to keep master weights in float32. + # Keeping the master weights in lower precision has shown to cause issues with convergence. + torch_dtype=torch.float32, + trust_remote_code=True, + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, + **hf_config_overrides, + ) + + self.allow_flash_attn_args = self.check_model_allow_flash_attn_args( + model_config + ) + + self._is_reward_model = ( + "reward_model_cfg" in self.cfg and self.cfg["reward_model_cfg"]["enabled"] + ) + if self._is_reward_model: + # Ensure sequence packing is disabled. + if self.enable_seq_packing: + raise NotImplementedError( + "Sequence packing is not supported for reward models" + ) + # Load model as a Reward Model. + rm_type = self.cfg["reward_model_cfg"]["reward_model_type"] + if rm_type == "bradley_terry": + model_class = NeMoAutoModelForSequenceClassification + if model_config.num_labels != 1: + # For Bradley-Terry reward models, the linear head has a single output. 
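As background for why `num_labels` must be 1 here: a Bradley-Terry reward model compares two sequences through one scalar reward each, so the head needs exactly one output feature. A toy sketch of the pairwise preference probability (illustrative only, not part of the worker):

```python
import math

def bt_preference_prob(reward_chosen: float, reward_rejected: float) -> float:
    # Bradley-Terry: P(chosen preferred over rejected) = sigmoid(r_chosen - r_rejected).
    # A single scalar reward per sequence is all this comparison needs, hence
    # out_features == 1 (num_labels == 1) on the linear head.
    return 1.0 / (1.0 + math.exp(-(reward_chosen - reward_rejected)))

print(round(bt_preference_prob(2.0, 0.0), 3))  # 0.881
# equal rewards give probability 0.5; a higher chosen reward pushes it above 0.5
```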
+ # In the transformers library, the default setting for model_config.num_labels is 2 + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/configuration_utils.py#L259). + # Since num_labels is used as the out_features for the linear head + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/modeling_llama.py#L738) + # if num_labels is not 1, we set it to 1. This change may trigger a warning that some weights are not initialized + # from the model checkpoint and are instead initialized using model_config.initializer_range + # (https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/llama/configuration_llama.py#L62). + print( + "model_config.num_labels is not 1. Setting it to 1 since this value is used as the out_features " + "for the linear head of Bradley-Terry reward models." + ) + model_config.num_labels = 1 + else: + raise ValueError(f"Unknown reward model type: {rm_type}") + else: + # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc. + model_class = resolve_model_class(model_config.model_type) + + # lora config + lora_cfg = self.cfg["dtensor_cfg"].get("lora_cfg", None) + self.peft_config = None + self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] + if self.lora_enabled: + if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1: + assert not lora_cfg["use_triton"], ( + "Triton is not supported when tensor_parallel_size > 1" + ) + # Always use float32 since FSDP requires all parameters to be in the same dtype. + # autocast should cast the weights to the correct dtype during the forward pass. + cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} + self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) + + print(f"[Rank {self.rank}] Initializing empty model for FSDP...") + # All ranks initialize model on meta device, so FSDP can shard it. 
+ # The actual weights will be broadcast from rank 0. + + cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + automodel_kwargs = self.cfg["dtensor_cfg"].get("automodel_kwargs", {}) + if automodel_kwargs.get("backend", None) is not None: + backend_class = _resolve_target( + automodel_kwargs.get("backend", None)["_target_"] + ) + backend_kwargs = automodel_kwargs.get("backend") + backend_kwargs.pop("_target_") + backend = backend_class( + **backend_kwargs, + ) + automodel_kwargs["backend"] = backend + + if "use_liger_kernel" not in automodel_kwargs: + automodel_kwargs["use_liger_kernel"] = False + + with init_empty_weights(): + from torch.nn.attention import SDPBackend + + if cp_size > 1: + # Match Automodel's `get_train_context` in `cp_utils.py` where only + # flash and efficient backends are supported + # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57 + sdpa_method = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + ] + elif self.cfg["dtensor_cfg"]["activation_checkpointing"]: + # For activation checkpointing, we must disable the cudnn SDPA backend because + # it may not be selected during recomputation. + # In that case, we will get the following error: + # "Recomputed values have different metadata than during forward pass." + sdpa_method = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + else: + sdpa_method = None + + self.model = model_class.from_pretrained( + model_name, + attn_implementation=attn_impl, + torch_dtype=str(model_config.torch_dtype), + trust_remote_code=True, + config=model_config, + sdpa_method=sdpa_method, + **automodel_kwargs, + ) + if self.lora_enabled: + apply_lora_to_linear_modules(self.model, self.peft_config) + + # For activation checkpointing, we also must globally disable the cudnn SDPA backend + # to ensure that cudnn does not get selected during recomputation. 
+        if self.cfg["dtensor_cfg"]["activation_checkpointing"]:
+            from torch.backends import cuda
+
+            cuda.enable_cudnn_sdp(False)
+
+        # Hold a copy of model state_dict keys before any parallelization
+        self.model_state_dict_keys = list(self.model.state_dict().keys())
+
+        if self.model.config.pad_token_id is None:
+            self.model.config.pad_token_id = tokenizer.pad_token_id
+
+        tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
+        ep_size = self.cfg["dtensor_cfg"].get("expert_parallel_size", 1)
+        dp_size = None  # will be inferred
+        if cp_size > 1 and self.enable_seq_packing:
+            raise ValueError(
+                "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
+            )
+        sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"]
+
+        if sequence_parallel_enabled and tp_size == 1:
+            print(
+                "[WARNING]: sequence_parallel=True, but tp_size=1, which has no effect. Enable tp_size > 1 to use sequence parallelism."
+            )
+
+        if cp_size > 1:
+            assert not isinstance(self.model, Gemma3ForCausalLM), (
+                "Context parallel is not supported for Gemma3ForCausalLM. Torch context parallel has many limitations. "
+                "Please refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
+            )
+
+            assert not (tp_size > 1 and sequence_parallel_enabled), (
+                "It's a known issue that context parallel can't be used together with sequence parallel in the DTensor worker. "
+                "Please either set cp_size = 1 or disable sequence parallel. "
+                "See https://github.com/NVIDIA-NeMo/RL/issues/659 for more details."
+            )
+
+            assert not self.is_vlm, (
+                "Context parallel is not yet supported for VLM models. Please set cp_size = 1 to train VLM models."
+ ) + + # ------------------------------------------------ + # Build device mesh and parallelize + # ------------------------------------------------ + manager = FSDP2Manager( + dp_size=dp_size, + dp_replicate_size=1, + tp_size=tp_size, + cp_size=cp_size, + ep_size=ep_size, + pp_size=1, + sequence_parallel=sequence_parallel_enabled, + use_hf_tp_plan=self.cfg["dtensor_cfg"].get("use_hf_tp_plan", False), + mp_policy=MixedPrecisionPolicy( + param_dtype=self.dtype, + reduce_dtype=torch.float32, + output_dtype=torch.float32, + ), + offload_policy=CPUOffloadPolicy(pin_memory=False) + if self.cpu_offload + else None, + backend="nccl", + world_size=world_size, + activation_checkpointing=self.cfg["dtensor_cfg"][ + "activation_checkpointing" + ], + custom_tp_plan=self.cfg["dtensor_cfg"].get("custom_parallel_plan", None), + ) + + # Force setup distributed for world size 1 as FSDP2Manager skips it. + if world_size == 1: + manager._setup_distributed() + + # Store mesh references for downstream usage + self.device_mesh = manager.device_mesh + self.dp_cp_mesh = self.device_mesh["dp_cp"] + self.dp_mesh = self.device_mesh["dp"] + self.tp_mesh = self.device_mesh["tp"] + self.cp_mesh = self.device_mesh["cp"] + self.moe_mesh = getattr(manager, "moe_mesh", None) + + self.dp_size = manager.dp_size + self.tp_size = manager.tp_size + self.cp_size = manager.cp_size + + # Parallelize model + is_moe_model = any(["expert" in key for key in self.model_state_dict_keys]) + is_hf_model = ( + model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls + ) + if ( + not isinstance(self.model, PreTrainedModel) + and is_moe_model + and not is_hf_model + ): + assert self.tp_size == 1, ( + "Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support tp_size > 1. Please use expert_parallel_size > 1 for custom implementation or set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." 
+ ) + assert self.cp_size == 1, ( + "Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support cp_size > 1. Please set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." + ) + moe_parallelize_model( + model=self.model, + world_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + pp_enabled=False, + dp_axis_names=( + ("dp_replicate", "dp_shard_cp") + if "dp_replicate" in self.device_mesh.mesh_dim_names + and "dp_shard_cp" in self.device_mesh.mesh_dim_names + else ("dp_shard_cp",) + ), + cp_axis_name="cp", + tp_axis_name="tp", + ep_axis_name="ep", + ep_shard_axis_names=("ep_shard",), + ) + else: + self.model = manager.parallelize(self.model) + + # Load base model weights across all ranks using Automodel Checkpointer + # This mirrors build_model_and_optimizer's is_meta_device + load_weights path + print(self.model) + self._init_checkpoint_manager( + config_updates={ + "model_repo_id": model_name, + "dequantize_base_checkpoint": self.cfg.get( + "dequantize_base_checkpoint", False + ), + "is_peft": self.lora_enabled, + }, + ) + self.checkpoint_manager.set_model_state_dict_keys(self.model_state_dict_keys) + + # Load base HF weights unless an explicit checkpoint is provided later + # This puts shards directly into the parallelized model + self.checkpoint_manager.load_base_model( + self.model, + model_name=model_name, + hf_cache_dir=hf_config_overrides.get("cache_dir", None), + dequantize_base_checkpoint=self.cfg.get( + "dequantize_base_checkpoint", False + ), + peft_init_method=self.peft_config.lora_A_init + if self.peft_config is not None + else None, + ) + + # Handle tied word embeddings after loading the state dict + # We need to actually tie the parameters at the model level + is_tied_lm_head = hasattr(self.model, "lm_head") and getattr( + getattr(self.model, "config", {}), "tie_word_embeddings", False + ) + if is_tied_lm_head: + embed_tokens_weight = None + for 
name, param in self.model.named_parameters(): + if "embed_tokens" in name and name.endswith(".weight"): + embed_tokens_weight = param + break + + if embed_tokens_weight is not None: + self.model.lm_head.weight = embed_tokens_weight + print( + f"[Rank {self.rank}] lm_head weight tied: " + f"same object = {self.model.lm_head.weight is embed_tokens_weight}, " + f"embed norm = {embed_tokens_weight.data.float().norm().item():.4f}, " + f"lm_head norm = {self.model.lm_head.weight.data.float().norm().item():.4f}" + ) + else: + print(f"[Rank {self.rank}] WARNING: embed_tokens weight not found, lm_head NOT tied") + else: + print( + f"[Rank {self.rank}] lm_head tying skipped: " + f"has_lm_head={hasattr(self.model, 'lm_head')}, " + f"tie_word_embeddings={getattr(getattr(self.model, 'config', {}), 'tie_word_embeddings', 'MISSING')}" + ) + + if self.cpu_offload: + self.model = self.move_to_device(self.model, "cpu") + + if init_reference_model: + self.reference_model_state_dict = get_cpu_state_dict( + self.model.state_dict().items(), pin_memory=True + ) + + if init_optimizer: + optimizer_cls = get_class(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"], + ) + else: + self.optimizer = None + + if "scheduler" in self.cfg and self.optimizer is not None: + if isinstance(self.cfg["scheduler"], dict): + scheduler_cls = get_class(cast(str, self.cfg["scheduler"]["name"])) + self.scheduler = scheduler_cls( + self.optimizer, **self.cfg["scheduler"]["kwargs"] + ) + else: + schedulers = [] + for scheduler_cfg in self.cfg["scheduler"]: + if "name" in scheduler_cfg: + schedulers.append( + get_class(scheduler_cfg["name"])( + self.optimizer, **scheduler_cfg["kwargs"] + ) + ) + else: + assert "milestones" in scheduler_cfg, ( + "unknown scheduler config: ", + scheduler_cfg, + ) + milestones: list[int] = scheduler_cfg["milestones"] + + self.scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, schedulers, 
milestones + ) + + elif self.optimizer is not None: + ## default to a passthrough LR schedule + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: 1 + ) + + # restore + if weights_path: + print(f"Loading weights from {weights_path}") + self.load_checkpoint(weights_path, optimizer_path) + if self.rank == 0: + for name, param in self.model.named_parameters(): + _p = param.data.float() + if torch.isnan(_p).any() or torch.isinf(_p).any(): + print(f" [NaN debug rank-0] CORRUPTED param after checkpoint load: {name}, has_nan={torch.isnan(_p).any().item()}, has_inf={torch.isinf(_p).any().item()}") + break + else: + print( + "No weights path provided. Loaded base HF weights via Checkpointer (default policy init)" + ) + + # ── Teacher model initialization (distillation-specific) ── + self.teacher_model = None + self.teacher_logits = None + self._teacher_topk_vals_buffer = None + self._teacher_topk_idx_buffer = None + self.teacher_cfg = teacher_config + + if teacher_config is not None: + teacher_model_name = teacher_config["model_name"] + teacher_hf_overrides = teacher_config.get("hf_config_overrides", {}) or {} + teacher_model_config = AutoConfig.from_pretrained( + teacher_model_name, + torch_dtype=torch.float32, + trust_remote_code=True, + **teacher_hf_overrides, + ) + teacher_model_class = resolve_model_class(teacher_model_config.model_type) + + with init_empty_weights(): + self.teacher_model = teacher_model_class.from_pretrained( + teacher_model_name, + attn_implementation=attn_impl, + torch_dtype=str(teacher_model_config.torch_dtype), + trust_remote_code=True, + config=teacher_model_config, + sdpa_method=sdpa_method, + ) + + self.teacher_model = manager.parallelize(self.teacher_model) + + teacher_ckpt_manager = AutomodelCheckpointManager( + dp_mesh=self.dp_mesh, + tp_mesh=self.tp_mesh, + model_state_dict_keys=list(self.teacher_model.state_dict().keys()), + moe_mesh=self.moe_mesh, + ) + teacher_ckpt_manager.init_checkpointer( + 
+            config_updates={
+                "model_repo_id": teacher_model_name,
+                "dequantize_base_checkpoint": teacher_config.get(
+                    "dequantize_base_checkpoint", False
+                ),
+                "is_peft": False,
+            },
+        )
+        teacher_ckpt_manager.set_model_state_dict_keys(
+            list(self.teacher_model.state_dict().keys())
+        )
+        teacher_ckpt_manager.load_base_model(
+            self.teacher_model,
+            model_name=teacher_model_name,
+            hf_cache_dir=teacher_hf_overrides.get("cache_dir", None),
+            dequantize_base_checkpoint=teacher_config.get(
+                "dequantize_base_checkpoint", False
+            ),
+        )
+        self.teacher_model.eval()
+        for p in self.teacher_model.parameters():
+            p.requires_grad_(False)
+        print(f"[Rank {self.rank}] Teacher model ({teacher_model_name}) loaded and parallelized")
+
+    def _apply_temperature_scaling(self, logits: torch.Tensor, skip: bool = False) -> torch.Tensor:
+        if skip:
+            return logits
+        if "generation" in self.cfg and self.cfg["generation"] is not None:
+            temp = self.cfg["generation"]["temperature"]
+            if temp > 0:
+                logits.div_(temp)
+        return logits
+
+    def check_model_allow_flash_attn_args(self, model_config) -> bool:
+        # Some models don't support flash_attn_kwargs.
+        # Check Nemotron NAS.
+ if ( + model_config.architectures[0] == "DeciLMForCausalLM" + and model_config.model_type == "nemotron-nas" + ): + return False + + return True + + @wrap_with_nvtx_name("dtensor_distillation_worker/train") + def train( + self, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + teacher_logits: Optional[dict] = None, + ) -> dict[str, Any]: + """Train the student policy, optionally using teacher logprobs from IPC buffers.""" + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + total_dataset_size = torch.tensor(data.size, device="cuda") + torch.distributed.all_reduce( + total_dataset_size, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_mesh.get_group(), + ) + num_global_batches = int(total_dataset_size.item()) // gbs + + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + + if eval_mode: + ctx: AbstractContextManager[Any] = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + # Ensure model is in training mode + self.model.train() + + with ctx: + # If teacher_forward() was called on this same worker, read teacher + # logprobs directly from the GPU IPC buffers (self.teacher_logits). + # No data round-trip through Ray — same process, same GPU. 
+ ipc_source = teacher_logits if teacher_logits is not None else self.teacher_logits + if ipc_source is not None and isinstance(ipc_source, dict): + current_device_id = torch.cuda.current_device() + rank = torch.distributed.get_rank() + actual_shape = ipc_source.get('actual_shape') + is_topk = ipc_source.get('is_topk', False) + + teacher_logits_tensor = rebuild_cuda_tensor_from_ipc( + ipc_source[rank], current_device_id + ).detach() + if actual_shape is not None: + aB, aS, aK = actual_shape + teacher_logits_tensor = teacher_logits_tensor[:aB, :aS, :aK].clone() + data["teacher_topk_logits"] = teacher_logits_tensor + + if is_topk and 'topk_indices_ipc' in ipc_source: + teacher_topk_indices_tensor = rebuild_cuda_tensor_from_ipc( + ipc_source['topk_indices_ipc'], current_device_id + ).detach() + if actual_shape is not None: + teacher_topk_indices_tensor = teacher_topk_indices_tensor[:aB, :aS, :aK].clone() + data["teacher_topk_indices"] = teacher_topk_indices_tensor + + # Get data from batch and move to device + data.to("cuda") + + losses = [] + all_mb_metrics = [] + for gb_idx in range(num_global_batches): + global_batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + + assert "sample_mask" in global_batch, ( + "sample_mask must be present in the data!" 
+ ) + ## get the normalization factor for the loss + local_valid_seqs = torch.sum(global_batch["sample_mask"]) + + if "token_mask" not in global_batch: + local_valid_toks = ( + local_valid_seqs * global_batch["input_ids"].shape[1] + ) + else: + local_valid_toks = torch.sum( + global_batch["token_mask"][:, 1:] + * global_batch["sample_mask"].unsqueeze(-1) + ) + + to_reduce = torch.tensor([local_valid_seqs, local_valid_toks]).cuda() + torch.distributed.all_reduce(to_reduce, group=self.dp_mesh.get_group()) + global_valid_seqs, global_valid_toks = to_reduce[0], to_reduce[1] + + if ( + hasattr(loss_fn, "loss_type") + and loss_fn.loss_type == LossType.TOKEN_LEVEL + ): + assert "token_mask" in global_batch, ( + "token_mask must be present in the data when using token-level loss" + ) + + self.optimizer.zero_grad() + mb_losses = [] + batch = data.get_batch(batch_idx=gb_idx, batch_size=local_gbs) + # Calculate number of microbatches to process + # make_microbatch_iterator assumes that the batch size is a multiple of the microbatch size + # so its safe to not check for the case where the last data slice is smaller than mbs + dummy_iterator = iter([]) + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = batch.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = batch.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = ( + batch.make_microbatch_iterator_for_packable_sequences() + ) + iterator_len, max_seqlen = ( + batch.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
+                    dummy_batch_ct = int(max_batch_ct.item() - iterator_len)
+                    dummy_iterator = (
+                        batch.make_microbatch_iterator_for_packable_sequences()
+                    )
+                    dummy_iterator = itertools.islice(
+                        itertools.cycle(dummy_iterator), dummy_batch_ct
+                    )
+                else:
+                    mb_iterator = batch.make_microbatch_iterator(mbs)
+                    iterator_len = batch.size // mbs
+
+                empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get(
+                    "clear_cache_every_n_steps"
+                )
+                if empty_cache_steps:
+                    warnings.warn(
+                        f"Emptying cache every {empty_cache_steps} microbatches; doing so unnecessarily would incur a large performance overhead."
+                    )
+
+                for mb_idx, mb in enumerate(
+                    itertools.chain(mb_iterator, dummy_iterator)
+                ):
+                    # Conditionally empty cache when sensitive to fragmentation
+                    if empty_cache_steps and mb_idx % empty_cache_steps == 0:
+                        torch.cuda.empty_cache()
+
+                    with torch.autocast(device_type="cuda", dtype=self.dtype):
+                        if self.enable_seq_packing:
+                            input_ids = mb.get("input_ids").cuda()
+                            input_ids, position_ids, _ = pack_sequences(
+                                input_ids=input_ids,
+                                input_lengths=mb["input_lengths"],
+                                packed_sequence_size=[
+                                    len(mb["input_lengths"])
+                                ],  # flash attention 2 expects flattened input
+                                padding_value=self.tokenizer.eos_token_id,
+                                return_attention_mask=False,
+                                min_seq_len=self.cfg["sequence_packing"][
+                                    "train_mb_tokens"
+                                ],  # TODO: this is a WAR for sequence packing; we should fix this. Without it, backward will fail when TP is enabled.
+ ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=mb["input_lengths"], + ) + + else: + input_ids = mb.get("input_ids").cuda() + batch_size, seq_len = input_ids.shape + + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} + + # add vlm kwargs to model call + vlm_kwargs = mb.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + if len(vlm_kwargs) > 0: + position_ids = None + assert not self.cfg["dtensor_cfg"]["sequence_parallel"], ( + "Sequence parallel is not supported with multimodal since there's an issue when you do not pass position_ids. See https://github.com/NVIDIA-NeMo/Automodel/issues/652" + ) + + context_parallel_ctx = None + if self.cp_size > 1: + assert len(vlm_kwargs) == 0, ( + f"multimodal kwargs={vlm_kwargs} are not supported for context parallel" + ) + seq_index = torch.arange( + seq_len, device=input_ids.device + ).repeat(1, 1) + cp_buffers = ( + [input_ids, position_ids, seq_index] + if self.cp_size > 1 + else [] + ) + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + + if self._is_reward_model: + # `flash_attn_kwarg` is not supported for `LlamaForSequenceClassification`. + # Note that it should be empty anyway since sequence packing + # is not supported for reward models. 
+ assert not flash_attn_kwargs + del model_args["flash_attn_kwargs"] + # remove flash_attn_kwargs if there are multimodal kwargs + if len(vlm_kwargs) > 0: + del model_args["flash_attn_kwargs"] + + if ( + not self.allow_flash_attn_args + and "flash_attn_kwargs" in model_args + ): + del model_args["flash_attn_kwargs"] + + outputs = self.model(**model_args) + + # Get logprobs + if isinstance(outputs, (torch.Tensor, DTensor)): + # custom models (e.g., those coming from AutoModel) can output logits directly + logits = outputs + elif not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + # Temperature scaling is only for inference/generation, not training + logits = self._apply_temperature_scaling(logits, skip=True) + + if self.cp_size > 1: + seq_index_dtensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + mb["seq_index"] = seq_index_dtensor + + for tensor_name in mb: + current_tensor = mb[tensor_name] + for buffer in cp_buffers: + if current_tensor is buffer: + assert type(current_tensor) == torch.Tensor, ( + f"tensor {tensor_name} is not a tensor" + ) + mb[tensor_name] = DTensor.from_local( + current_tensor, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + break + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + if self.enable_seq_packing: + loss_fn_ = 
SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, + cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, + ) + else: + loss_fn_ = loss_fn + + # ── NaN debug: inspect logits on first microbatch of first 2 steps ── + if mb_idx == 0 and gb_idx == 0 and self.rank == 0 and len(losses) < 2: + _local_logits = logits.to_local() if isinstance(logits, DTensor) else logits + _lf = _local_logits.float() + print( + f" [NaN debug rank-0] logits shape={_local_logits.shape}, " + f"dtype={_local_logits.dtype}, " + f"min={_lf.min().item():.4f}, max={_lf.max().item():.4f}, " + f"has_nan={torch.isnan(_lf).any().item()}, " + f"has_inf={torch.isinf(_lf).any().item()}, " + f"global_valid_toks={global_valid_toks.item():.0f}, " + f"global_valid_seqs={global_valid_seqs.item():.0f}", + flush=True, + ) + del _local_logits, _lf + + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) + del logits + + # skip the update for dummy batches + if mb_idx < iterator_len: + ## scale by the number of global batches so we get the correct + ## value when summing metrics across all microbatches + for k in loss_metrics.keys(): + loss_metrics[k] /= num_global_batches + num_valid_samples = loss_metrics["num_valid_samples"] + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] + loss_metrics["global_valid_seqs"] = global_valid_seqs.item() + loss_metrics["global_valid_toks"] = global_valid_toks.item() + else: + loss *= 0 + + # Backward pass + if not eval_mode: + ## NOTE: invalid samples should be multiplied + ## by zero in the loss function to prevent them + ## from affecting the gradient calculation + + # when FSDP reduces the gradients over the DP dim, they're automatically averaged + # but we want to sum them so we cancel out the average here + loss *= self.dp_size * self.cp_size + loss.backward() + + if num_valid_samples > 0: + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + grad_norm: 
Optional[float | torch.Tensor] = None
+                if not eval_mode:
+                    grad_norm = scale_grads_and_clip_grad_norm(
+                        self.max_grad_norm,
+                        [self.model],
+                        norm_type=2.0,
+                        pp_enabled=False,
+                        device_mesh=self.device_mesh,
+                        moe_mesh=self.moe_mesh,
+                        ep_axis_name="ep"
+                        if self.moe_mesh is not None
+                        and "ep" in self.moe_mesh.mesh_dim_names
+                        else None,
+                        pp_axis_name=None,
+                        foreach=True,
+                        num_label_tokens=1,
+                        dp_group_size=self.dp_size * self.cp_size,
+                    )
+                    grad_norm = torch.tensor(
+                        grad_norm, device="cpu", dtype=torch.float32
+                    )
+
+                    # Update parameters
+                    self.optimizer.step()
+
+                losses.append(torch.tensor(mb_losses).sum().item())
+
+            # release gradient memory before rollouts
+            self.optimizer.zero_grad()
+            # increment scheduler after all batches in rollout are processed
+            if not eval_mode:
+                self.scheduler.step()
+            # dynamic batch and sequence dims cause a lot of fragmentation, so clear
+            # the memory allocator before moving on
+            torch.cuda.empty_cache()
+
+            # Compute global loss across all ranks
+            with torch.no_grad():
+                global_loss = torch.tensor(losses, device="cuda")
+                torch.distributed.all_reduce(
+                    global_loss, group=self.dp_mesh.get_group()
+                )
+            # Aggregate metrics across all microbatches
+            mb_metrics = defaultdict(list)
+            for m in all_mb_metrics:
+                for k, v in m.items():
+                    mb_metrics[k].append(v)
+
+            metrics = {
+                "global_loss": global_loss.cpu(),
+                "grad_norm": grad_norm,
+                "rank": torch.distributed.get_rank(),
+                "gpu_name": torch.cuda.get_device_name(),
+                "model_dtype": self.dtype,
+                "all_mb_metrics": dict(mb_metrics),
+            }
+
+            return metrics
+
+    # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094)
+    @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_logprobs")
+    def get_logprobs(
+        self, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
+    ) -> BatchedDataDict[LogprobOutputSpec]:
+        """Get the logprobs of the model for a batch of data.
+ + Uses the configured logprob_batch_size to do microbatching. + + Input data is assumed to be right-padded. The method internally converts to + left-padded format for computation, and returns outputs in right-padded format. + + Returns: + a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) + + # dim 1 is always assumed to be the sequence dim, sanity check this here + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + for k, v in data.items(): + if torch.is_tensor(v) and len(v.shape) > 1: + assert v.shape[sequence_dim] == seq_dim_size, ( + f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" + ) + + all_log_probs = [] + self.model.eval() + + with torch.no_grad(): + data.to("cuda") + dummy_iterator = iter([]) + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(logprob_batch_size) + iterator_len = data.size // logprob_batch_size + + step = 0 + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + step += 1 + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + vlm_kwargs = lp_batch.get_multimodal_dict( + as_tensors=True, device=input_ids.device + ) + + batch_size, seq_len = input_ids.shape + if self.enable_seq_packing: + assert len(vlm_kwargs) == 0, ( + "multimodal kwargs are not supported for sequence packing" + ) + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Create post_attention_mask for right-padded data for masking token after forwarding. 
+ post_attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.bool, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + post_attention_mask[i, :length] = 1 + + # explicitly create position ids for the input, otherwise the sharding + # for DTensor will be incorrect + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + flash_attn_kwargs = {} + + # DTensor requires the casual attention kernel to hit, + # yet our attention mask above is not always all 1s + # this is fine because we mask with the actual attention mask + # later, but for input it has to be all 1s + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + + # if there are multimodal kwargs, we don't need to add position_ids (computed internally) + if len(vlm_kwargs) > 0: + position_ids = None + + context_parallel_ctx = None + if self.cp_size > 1: + assert len(vlm_kwargs) == 0, ( + "multimodal kwargs are not supported for context parallel" + ) + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + **vlm_kwargs, + ) + if len(vlm_kwargs) > 0: + del model_args["flash_attn_kwargs"] + + if ( + not self.allow_flash_attn_args + and "flash_attn_kwargs" in model_args + ): + del model_args["flash_attn_kwargs"] + + outputs = 
self.model(**model_args) + + logits = outputs.logits if hasattr(outputs, "logits") else outputs + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + seq_index_tensor = ( + DTensor.from_local( + seq_index, + device_mesh=self.cp_mesh, + placements=[Shard(1)], + ) + .full_tensor() + .squeeze(0) + ) + + input_ids_dtensor = DTensor.from_local( + input_ids, + device_mesh=self.cp_mesh, + placements=[Shard(sequence_dim)], + ) + + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits, + input_ids_dtensor, + seq_index_tensor, + chunk_size=logprob_chunk_size, + ) + + assert token_logprobs.shape[1] == seq_len - 1 + else: + if isinstance(logits, DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + logits, + input_ids, + chunk_size=logprob_chunk_size, + ) + else: + if logprob_chunk_size is not None: + logits_seq_len = int(logits.shape[1]) + num_chunks = ( + logits_seq_len + logprob_chunk_size - 1 + ) // logprob_chunk_size + chunked_log_probs = [] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * logprob_chunk_size + chunk_end = min( + logits_seq_len, + (chunk_idx + 1) * logprob_chunk_size, + ) + chunk_logits = logits[ + :, chunk_start:chunk_end, : + ].to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + chunk_logits, dim=-1 + ) + chunked_log_probs.append(log_probs) + log_probs = 
torch.cat(chunked_log_probs, dim=1) + del chunked_log_probs + else: + logits = logits.to(torch.float32) + log_probs = torch.nn.functional.log_softmax( + logits, dim=-1 + ) + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + next_tokens = input_ids[:, 1:] + log_probs = log_probs[:, :-1] + token_logprobs = log_probs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + del log_probs + + del outputs, logits + + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) + + # skip keeping the logprobs for the dummy batches + if batch_idx >= iterator_len: + continue + + if not self.enable_seq_packing: + # Apply mask to zero out padding tokens logprobs + token_logprobs = token_logprobs * post_attention_mask + else: + # For packed sequences, unpack logprobs + unpacked_logprobs = torch.zeros( + (batch_size, seq_dim_size), + dtype=token_logprobs.dtype, + device=token_logprobs.device, + ) + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + for i in range(batch_size): + start = cu_seqlens[i].item() + 1 + end = cu_seqlens[i + 1].item() + seq_len_actual = input_lengths[i].item() + unpacked_logprobs[i, 1:seq_len_actual] = token_logprobs[ + 0, start:end + ] + token_logprobs = unpacked_logprobs + + all_log_probs.append(token_logprobs) + + # Concatenate all batches + return_data = BatchedDataDict[LogprobOutputSpec]() + + all_log_probs_padded = [] + for lp in all_log_probs: + padding_needed = seq_dim_size - lp.shape[1] + if padding_needed > 0: + lp = torch.nn.functional.pad( + lp, (0, padding_needed), mode="constant", 
value=0.0
+                )
+            all_log_probs_padded.append(lp)
+        return_data["logprobs"] = torch.cat(all_log_probs_padded, dim=0).cpu()
+
+        return return_data
+
+    # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094)
+    @wrap_with_nvtx_name("dtensor_policy_worker_v2/score")
+    def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]:
+        global_batch_size = min(self.cfg["batch_size"], data.size)
+
+        sequence_dim = 1
+        seq_dim_size = data.get("input_ids").shape[sequence_dim]
+        for k, v in data.items():
+            if torch.is_tensor(v) and len(v.shape) > 1:
+                assert v.shape[sequence_dim] == seq_dim_size, (
+                    f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
+                )
+        self.model.eval()
+        print("Beginning to batch data for scoring")
+        with torch.no_grad():
+            data.to("cuda")
+            dummy_iterator = iter([])
+            if self.cfg["dynamic_batching"]["enabled"]:
+                mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes()
+                iterator_len = data.get_microbatch_iterator_dynamic_shapes_len()
+            elif self.enable_seq_packing:
+                mb_iterator = data.make_microbatch_iterator_for_packable_sequences()
+                iterator_len, max_seqlen = (
+                    data.get_microbatch_iterator_for_packable_sequences_len()
+                )
+                max_batch_ct = torch.tensor([iterator_len], device="cuda")
+                torch.distributed.all_reduce(
+                    max_batch_ct, op=torch.distributed.ReduceOp.MAX
+                )
+                dummy_batch_ct = int(max_batch_ct.item() - iterator_len)
+                dummy_iterator = data.make_microbatch_iterator_for_packable_sequences()
+                dummy_iterator = itertools.islice(
+                    itertools.cycle(dummy_iterator), dummy_batch_ct
+                )
+            else:
+                mb_iterator = data.make_microbatch_iterator(global_batch_size)
+                iterator_len = data.size // global_batch_size
+            step = 0
+            all_rm_scores = []
+            for batch_idx, generate_batch in enumerate(
+                itertools.chain(mb_iterator, dummy_iterator)
+            ):
+                step += 1
+                input_ids = generate_batch.get("input_ids").cuda()
+                input_lengths = 
generate_batch.get("input_lengths") + batch_size, seq_len = input_ids.shape + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Create attention mask for right-padded data + post_attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.bool, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + post_attention_mask[i, :length] = 1 + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + attention_mask = torch.ones( + (batch_size, seq_len), + dtype=torch.bool, + device=input_ids.device, + ) + context_parallel_ctx = None + if self.cp_size > 1: + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + model_args = dict( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ) + outputs = self.model(**model_args) + + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + if isinstance(logits, DTensor): + logits = 
logits.to(torch.float32)
+                    else:
+                        # `logits` may have come from lm_head and has already been
+                        # temperature-scaled; re-reading outputs.logits would discard
+                        # both, so cast the existing tensor instead.
+                        logits = logits.to(torch.float32)
+
+                rm_scores = to_local_if_dtensor(logits)
+                rm_scores = rm_scores.squeeze(-1)
+                all_rm_scores.append(rm_scores)
+
+            all_rm_scores = torch.cat(all_rm_scores, dim=0)
+            all_rm_scores = all_rm_scores.squeeze(-1).cpu()
+            return_data = BatchedDataDict[ScoreOutputSpec](
+                {
+                    "scores": all_rm_scores,
+                }
+            )
+            return return_data
+
+    @wrap_with_nvtx_name("dtensor_policy_worker_v2/get_topk_logits")
+    def get_topk_logits(
+        self,
+        data: BatchedDataDict[Any],
+        k: int,
+        micro_batch_size: Optional[int] = None,
+    ) -> BatchedDataDict[Any]:
+        """Return per-position top-k logits and corresponding global indices.
+
+        Notes:
+        - Return shapes are [B, S, k].
+        - Computes top-k over the full sequence (no trimming of the last position).
+        - If alignment with next-token targets is required, the caller should handle it.
+        - If logits are a TP-sharded DTensor, performs a distributed global top-k across TP.
+        - Supports context parallelism with a proper CP gather.
+        - Otherwise, computes a local top-k on the full-vocab tensor.
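+
+        Example (illustrative shapes): with B=2, S=16, k=8, the returned
+        BatchedDataDict holds "topk_logits" and "topk_indices", each of
+        shape [2, 16, 8], moved to CPU.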
+ """ + topk_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + + sequence_dim = 1 + seq_dim_size = data.get("input_ids").shape[sequence_dim] + + out_topk_vals = [] + out_topk_idx = [] + self.model.eval() + + with torch.no_grad(): + data.to("cuda") + dummy_iterator = iter([]) + if self.cfg["dynamic_batching"]["enabled"]: + # dynamic batching support (no CP/packed) + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + iterator_len = data.get_microbatch_iterator_dynamic_shapes_len() + elif self.enable_seq_packing: + mb_iterator = data.make_microbatch_iterator_for_packable_sequences() + iterator_len, max_seqlen = ( + data.get_microbatch_iterator_for_packable_sequences_len() + ) + max_batch_ct = torch.tensor([iterator_len], device="cuda") + torch.distributed.all_reduce( + max_batch_ct, op=torch.distributed.ReduceOp.MAX + ) + + # Sequence packing can end up with unevenly distributed batch counts across DP ranks. + # We add dummy batches to the end of the iterator to make the batch counts equal. 
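+                # Worked example (assumed counts): if this DP rank has 3 packed
+                # microbatches while another has 5, the all_reduce(MAX) above gives
+                # max_batch_ct = 5, so this rank cycles its iterator for
+                # 5 - 3 = 2 extra batches and every rank runs the same number of
+                # forward passes and collectives.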
+ dummy_batch_ct = int(max_batch_ct.item() - iterator_len) + dummy_iterator = data.make_microbatch_iterator_for_packable_sequences() + dummy_iterator = itertools.islice( + itertools.cycle(dummy_iterator), dummy_batch_ct + ) + else: + mb_iterator = data.make_microbatch_iterator(topk_batch_size) + iterator_len = data.size // topk_batch_size + + for batch_idx, lp_batch in enumerate( + itertools.chain(mb_iterator, dummy_iterator) + ): + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + + batch_size, seq_len = input_ids.shape + # Store original shapes for unpacking later + original_batch_size = batch_size + original_seq_len = seq_len + + if self.enable_seq_packing: + input_ids, position_ids, _ = pack_sequences( + input_ids=input_ids, + input_lengths=input_lengths, + packed_sequence_size=[ + batch_size + ], # flash attention 2 expects flattened input + padding_value=self.tokenizer.eos_token_id, + return_attention_mask=False, + ) + seq_len = input_ids.shape[1] + attention_mask = None + flash_attn_kwargs = get_flash_attention_kwargs( + input_lengths=input_lengths, + ) + else: + # Build attention mask (right-padded inputs) + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + flash_attn_kwargs = {} + + with torch.autocast(device_type="cuda", dtype=self.dtype): + attention_mask_input_all_ones = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + + context_parallel_ctx = None + if self.cp_size > 1: + seq_index = torch.arange(seq_len, device=input_ids.device).repeat( + 1, 1 + ) + cp_buffers = [input_ids, position_ids, seq_index] + + # Create context parallel context + context_parallel_ctx = create_context_parallel_ctx( + cp_mesh=self.cp_mesh, + cp_buffers=cp_buffers, + 
cp_seq_dims=[sequence_dim] * len(cp_buffers), + cp_no_restore_buffers=set(cp_buffers), + ) + + with get_train_context(False, False, context_parallel_ctx)(): + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask_input_all_ones, + position_ids=position_ids, + use_cache=False, + flash_attn_kwargs=flash_attn_kwargs, + ) + + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + # Apply temperature scaling + logits = self._apply_temperature_scaling(logits) + + if self.cp_size > 1: + if isinstance(logits, DTensor): + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) + + # deal with TP first + local_logits = logits.to_local() # [B, S_cp, V_tp] + + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + vals, idx = distributed_vocab_topk( + local_logits, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + # [B, S_cp, k] + + cp_group = self.cp_mesh.get_group() + + vals = allgather_cp_sharded_tensor( + vals, cp_group, seq_dim=sequence_dim + ) + idx = allgather_cp_sharded_tensor( + idx, cp_group, seq_dim=sequence_dim + ) + # [B, S, k] + else: + # Compute top-k over full sequence length (do not drop last position) + if 
isinstance(logits, DTensor): + local_logits = logits.to_local() # [B, S, V_local] + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + local_logits = local_logits.to(torch.float32) + local_log_probs = _compute_distributed_log_softmax(local_logits, group=tp_group) + del logits, local_logits + + if isinstance(local_log_probs, DTensor): + local_log_probs = local_log_probs.to_local() + + if self.cfg.get('is_mdlm', False): + shared_sequence_length = int(local_log_probs.shape[1] // 2) + local_log_probs = local_log_probs[:, shared_sequence_length:, :] + + vals, idx = distributed_vocab_topk( + local_log_probs, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del local_log_probs + else: + full_logits = logits.to(torch.float32) + vals, idx = torch.topk(full_logits, k=k, dim=-1) + + # Handle sequence packing unpacking + if self.enable_seq_packing: + # Unpack top-k results from packed format back to original batch format + # vals: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + # idx: [1, packed_seq_len, k] -> [original_batch_size, original_seq_len, k] + + # Create tensors to store unpacked results + unpacked_vals = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=vals.dtype, + device=vals.device, + ) + unpacked_idx = torch.zeros( + (original_batch_size, original_seq_len, k), + dtype=idx.dtype, + device=idx.device, + ) + + # Get cumulative sequence lengths for unpacking + cu_seqlens = flash_attn_kwargs.cu_seqlens_q + + for i in range(original_batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + seq_len_actual = input_lengths[i].item() + + # Extract the corresponding portion from packed results + # Note: vals and idx are [1, packed_seq_len, k] due to packing + unpacked_vals[i, :seq_len_actual, 
:] = vals[0, start:end, :] + unpacked_idx[i, :seq_len_actual, :] = idx[0, start:end, :] + + # Replace with unpacked results + vals = unpacked_vals + idx = unpacked_idx + + # Update batch_size and seq_len for consistency + batch_size = original_batch_size + seq_len = original_seq_len + + # Shapes remain [B, S, k]. + B_mb, S_mb, K_mb = vals.shape + target_dtype = vals.dtype + target_device = vals.device + + # Pre-allocate two IPC buffers (values + indices) exactly once. + if not hasattr(self, '_teacher_topk_vals_buffer') or self._teacher_topk_vals_buffer is None: + max_S = self.cfg.get("max_total_sequence_length", S_mb) + vals_buf_shape = (B_mb, max_S, K_mb) + self._teacher_topk_vals_buffer = torch.empty( + vals_buf_shape, dtype=target_dtype, device=target_device + ) + self._teacher_topk_vals_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_vals_buffer) + } + idx_buf_shape = (B_mb, max_S, K_mb) + self._teacher_topk_idx_buffer = torch.empty( + idx_buf_shape, dtype=idx.dtype, device=target_device + ) + self._teacher_topk_idx_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_idx_buffer) + } + print(f" rank {torch.distributed.get_rank()} Allocated topk IPC buffers: " + f"vals={vals_buf_shape} ({self._teacher_topk_vals_buffer.numel() * self._teacher_topk_vals_buffer.element_size() / 1e9:.4f} GB), " + f"idx={idx_buf_shape} ({self._teacher_topk_idx_buffer.numel() * self._teacher_topk_idx_buffer.element_size() / 1e9:.4f} GB) " + f"(actual data: [{B_mb}, {S_mb}, {K_mb}])") + + # Copy actual data into the top-left slice of the buffers + self._teacher_topk_vals_buffer[:B_mb, :S_mb, :K_mb].copy_(vals) + self._teacher_topk_idx_buffer[:B_mb, :S_mb, :K_mb].copy_(idx) + del vals, idx + + out_topk_vals.append(self._teacher_topk_vals_buffer[:B_mb, :S_mb, :K_mb].cpu()) + out_topk_idx.append(self._teacher_topk_idx_buffer[:B_mb, :S_mb, :K_mb].cpu()) + + ret = BatchedDataDict[Any]() + all_topk_vals_padded = [] + 
all_topk_idx_padded = [] + target_seq_len = seq_dim_size + for vals, idx in zip(out_topk_vals, out_topk_idx): + pad_needed = target_seq_len - vals.shape[1] + if pad_needed > 0: + vals = torch.nn.functional.pad( + vals, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0.0 + ) + idx = torch.nn.functional.pad( + idx, (0, 0, 0, pad_needed, 0, 0), mode="constant", value=0 + ) + all_topk_vals_padded.append(vals) + all_topk_idx_padded.append(idx) + + ret["topk_logits"] = ( + torch.cat(all_topk_vals_padded, dim=0) + if len(all_topk_vals_padded) > 1 + else all_topk_vals_padded[0] + ).cpu() + ret["topk_indices"] = ( + torch.cat(all_topk_idx_padded, dim=0) + if len(all_topk_idx_padded) > 1 + else all_topk_idx_padded[0] + ).cpu() + return ret + + @wrap_with_nvtx_name("dtensor_distillation_worker/teacher_forward") + def teacher_forward( + self, + data: BatchedDataDict[Any], + k: int, + micro_batch_size: Optional[int] = None, + ) -> dict: + """Run teacher forward pass and store top-k logprobs in GPU IPC buffers. + + Returns a dict with IPC handles and shape metadata (no tensor data + goes through Ray). The student's train() method uses + rebuild_cuda_tensor_from_ipc to access the same GPU memory directly. 
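+
+        The returned payload has the form (illustrative):
+            {<rank>: <top-k values IPC handle>,
+             'actual_shape': (B, S, K),
+             'topk_indices_ipc': <top-k indices IPC handle>,
+             'is_topk': True}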
+ """ + assert self.teacher_model is not None, "teacher_config must be provided to use teacher_forward" + + topk_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + + sequence_dim = 1 + self.teacher_model.eval() + + with torch.no_grad(): + data.to("cuda") + + if self.cfg["dynamic_batching"]["enabled"]: + mb_iterator = data.make_microbatch_iterator_with_dynamic_shapes() + else: + mb_iterator = data.make_microbatch_iterator(topk_batch_size) + + all_vals = [] + all_idx = [] + + for batch_idx, lp_batch in enumerate(mb_iterator): + input_ids = lp_batch.get("input_ids").cuda() + input_lengths = lp_batch.get("input_lengths") + batch_size, seq_len = input_ids.shape + + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + attention_mask[i, :length] = 1 + + position_ids = torch.arange( + seq_len, device=input_ids.device + ).repeat(batch_size, 1) + + attention_mask_all_ones = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + + with torch.autocast(device_type="cuda", dtype=self.dtype): + outputs = self.teacher_model( + input_ids=input_ids, + attention_mask=attention_mask_all_ones, + position_ids=position_ids, + use_cache=False, + ) + + if not hasattr(outputs, "logits"): + logits = self.teacher_model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + del outputs + + if isinstance(logits, DTensor): + local_logits = logits.to_local() + tp_group = self.tp_mesh.get_group() + tp_rank = torch.distributed.get_rank(tp_group) + V_local = int(local_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + local_logits = local_logits.to(torch.float32) + local_log_probs = _compute_distributed_log_softmax(local_logits, group=tp_group) + del logits, local_logits + + if isinstance(local_log_probs, DTensor): + local_log_probs = 
local_log_probs.to_local() + + vals, idx = distributed_vocab_topk( + local_log_probs, + k=k, + tp_group=tp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + del local_log_probs + else: + full_logits = logits.to(torch.float32) + vals, idx = torch.topk(full_logits, k=k, dim=-1) + + all_vals.append(vals) + all_idx.append(idx) + + # Concatenate all microbatch results + topk_logprobs = torch.cat(all_vals, dim=0) if len(all_vals) > 1 else all_vals[0] + topk_indices = torch.cat(all_idx, dim=0) if len(all_idx) > 1 else all_idx[0] + + B, S, K = topk_logprobs.shape + target_dtype = topk_logprobs.dtype + target_device = topk_logprobs.device + + if self._teacher_topk_vals_buffer is None: + max_S = self.cfg.get("max_total_sequence_length", S) + vals_buf_shape = (B, max_S, K) + self._teacher_topk_vals_buffer = torch.empty( + vals_buf_shape, dtype=target_dtype, device=target_device + ) + self._teacher_topk_vals_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_vals_buffer) + } + idx_buf_shape = (B, max_S, K) + self._teacher_topk_idx_buffer = torch.empty( + idx_buf_shape, dtype=topk_indices.dtype, device=target_device + ) + self._teacher_topk_idx_ipc = { + torch.distributed.get_rank(): get_handle_from_tensor(self._teacher_topk_idx_buffer) + } + print(f" rank {torch.distributed.get_rank()} Allocated topk IPC buffers: " + f"vals={vals_buf_shape} ({self._teacher_topk_vals_buffer.numel() * self._teacher_topk_vals_buffer.element_size() / 1e9:.4f} GB), " + f"idx={idx_buf_shape} ({self._teacher_topk_idx_buffer.numel() * self._teacher_topk_idx_buffer.element_size() / 1e9:.4f} GB) " + f"(actual data: [{B}, {S}, {K}])") + + self._teacher_topk_vals_buffer[:B, :S, :K].copy_(topk_logprobs) + self._teacher_topk_idx_buffer[:B, :S, :K].copy_(topk_indices) + del topk_logprobs, topk_indices + + rank = torch.distributed.get_rank() + self.teacher_logits = { + rank: self._teacher_topk_vals_ipc[rank], + 'actual_shape': (B, S, K), + 
'topk_indices_ipc': self._teacher_topk_idx_ipc[rank],
+            'is_topk': True,
+        }
+        return self.teacher_logits
+
+    @contextmanager
+    def use_reference_model(self) -> Generator[None, None, None]:
+        """Context manager that temporarily loads the reference model weights into the active model.
+
+        On entry: saves the training model's state dict to pinned CPU memory, then
+        copies the reference model's state dict into the live model in place.
+        On exit: restores the training model's state dict.
+        """
+        with torch.no_grad():
+            try:
+                # Save train model state_dict
+                curr_state_dict = get_cpu_state_dict(
+                    self.model.state_dict().items(), pin_memory=True
+                )
+
+                # Copy reference model state_dict into self.model
+                for k, v in self.model.state_dict().items():
+                    val = to_local_if_dtensor(v)
+                    val.copy_(self.reference_model_state_dict[k])
+
+                # - self.model now holds the reference model weights
+                # - curr_state_dict holds the training model weights, on CPU
+                yield
+
+            finally:
+                # Restore train model state_dict
+                for k, v in self.model.state_dict().items():
+                    val = to_local_if_dtensor(v)
+                    val.copy_(curr_state_dict[k])
+
+    def _add_noise_to_weights(self) -> None:
+        """Add small Gaussian noise to the weights of the model. Note that this is used for testing purposes only."""
+        noise_std = 0.01  # Standard deviation for the noise
+        for p in self.model.parameters():
+            if p.requires_grad:
+                noise = torch.randn_like(p.data) * noise_std
+                p.data.add_(noise)  # Add noise in-place
+        torch.cuda.synchronize()
+
+    def return_state_dict(self):
+        return self.model.state_dict()
+
+    def return_model_config(self) -> dict[str, Any]:
+        """Return the model configuration as a dictionary.
+
+        Returns:
+            dict: Model configuration dictionary
+        """
+        return self.model.config
+
+    @torch.no_grad()
+    def prepare_refit_info(self) -> Optional[dict[str, Any]]:
+        """Prepare state dict metadata for weight refitting and IPC streaming."""
+        state_dict_info = {}
+        for name, tensor in self.model.state_dict().items():
+            # All tensors will be cast to self.dtype in stream_weights_via_ipc_zmq/broadcast_weights_for_collective
+            state_dict_info[name] = (tensor.shape, self.dtype)
+
+        return state_dict_info
+
+    @torch.no_grad()
+    def calibrate_qkv_fp8_scales(
+        self,
+        data: BatchedDataDict[Any],
+        micro_batch_size: Optional[int] = None,
+        percentile: float = 99.9,
+        margin: float = 1.05,
+        include_q: bool = False,
+    ) -> dict[str, Any]:
+        """Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorkerV2."""
+        raise NotImplementedError(
+            "calibrate_qkv_fp8_scales is not implemented for DTensorPolicyWorkerV2"
+        )
+
+    @torch.no_grad()
+    @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_ipc_zmq")
+    def stream_weights_via_ipc_zmq(
+        self,
+        buffer_size_bytes: int = 0,
+        kv_scales: Optional[dict[str, float]] = None,
+    ) -> None:
+        """Stream model weights to the peer process via a ZMQ IPC socket."""
+        if kv_scales is not None:
+            raise NotImplementedError(
+                "FP8 kvcache is not currently supported for the DTensor path; we will support it in the future."
+            )
+
+        self.maybe_init_zmq()
+        # Manually move the model to CUDA for the CPU-offload case
+        if self.cpu_offload:
+            self.model = self.move_to_cuda(self.model)
+
+        from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl
+
+        def dtensor_params_generator():
+            """Generator that yields (name, tensor) pairs, converting DTensors to full tensors."""
+            for name, tensor in self.model.state_dict().items():
+                if isinstance(tensor, DTensor):
+                    # Convert DTensor to full tensor for streaming
+                    full_tensor = tensor.full_tensor()
+                    # Convert to target dtype
+                    yield (
+                        name,
+                        full_tensor.to(self.dtype, non_blocking=True).contiguous(),
+                    )
+                else:
+                    # Convert to target dtype
+                    yield name, tensor.to(self.dtype, non_blocking=True).contiguous()
+
+        # Use the shared implementation
+        stream_weights_via_ipc_zmq_impl(
+            params_generator=dtensor_params_generator(),
+            buffer_size_bytes=buffer_size_bytes,
+            zmq_socket=self.zmq_socket,
+            rank=self.rank,
+            worker_name=str(self),
+        )
+
+    @torch.no_grad()
+    def broadcast_weights_for_collective(
+        self, kv_scales: Optional[dict[str, float]] = None
+    ) -> None:
+        """Broadcast the weights for collective communication."""
+        if kv_scales is not None:
+            raise NotImplementedError(
+                "FP8 kvcache is not currently supported for the DTensor path; we will support it in the future."
+            )
+
+        # Manually move the model to CUDA for the CPU-offload case
+        if self.cpu_offload:
+            print(
+                "[WARNING]: Unless you are short on GPU memory, it is not recommended to enable cpu_offload when "
+                "using non-colocated generation, since it incurs an extra onload and offload at the refit stage."
+ ) + self.model = self.move_to_cuda(self.model) + + def _dtensor_post_iter_func(tensor, dtype): + if isinstance(tensor, DTensor): + tensor = tensor.full_tensor() + tensor = tensor.to(dtype, non_blocking=True) + return tensor + + # param_iterator will return (name, tensor), we only need tensor + dtensor_post_iter_func = lambda x: _dtensor_post_iter_func(x[1], self.dtype) + + packed_broadcast_producer( + iterator=iter(self.model.state_dict().items()), + group=self.model_update_group, + src=0, + post_iter_func=dtensor_post_iter_func, + ) + + # Manually move model to cpu for cpu offload case + # cpu offload needs model on CPU before model forward + if self.cpu_offload: + self.model = self.move_to_cpu(self.model) + + @wrap_with_nvtx_name("dtensor_policy_worker_v2/prepare_for_lp_inference") + def prepare_for_lp_inference(self) -> None: + # onload model to cuda + if not self.cpu_offload: + self.move_to_cuda(self.model) + else: + self.model = self.move_buffer_to_device(self.model, "cuda") + + self.model.eval() + + # offload optimizer to cpu + torch.randn(1).cuda() # wake up torch allocator + if self.optimizer is not None and self.offload_optimizer_for_logprob: + self.move_optimizer_to_device("cpu") + + gc.collect() + torch.cuda.empty_cache() + + @wrap_with_nvtx_name("dtensor_policy_worker_v2/prepare_for_training") + def prepare_for_training(self, *args, **kwargs) -> None: + # onload models and optimizer state to cuda + if not self.cpu_offload: + self.move_to_cuda(self.model) + else: + # when cpu offload is enabled, the buffers do not get moved + # to cuda automatically, so we need to do that manually + self.model = self.move_buffer_to_device(self.model, "cuda") + + self.model.train() + # Move optimizer state to CUDA if it exists + # colocated generation will always offload optimizer to cuda before refit + if ( + self.optimizer is not None + and not self.cpu_offload + and (self.offload_optimizer_for_logprob or self.is_generation_colocated) + ): + 
self.move_optimizer_to_device("cuda") + + torch.cuda.empty_cache() + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker_v2/offload_before_refit") + def offload_before_refit(self) -> None: + """Offload the optimizer to the CPU.""" + torch.randn(1).cuda() # wake up torch allocator + if self.optimizer is not None: + self.move_optimizer_to_device("cpu") + + gc.collect() + torch.cuda.empty_cache() + + @torch.no_grad() + @wrap_with_nvtx_name("dtensor_policy_worker_v2/offload_after_refit") + def offload_after_refit(self) -> None: + """Offload as much as possible on the CPU.""" + self.model = self.move_to_cpu(self.model) + self.model.eval() + torch.randn(1).cuda() # wake up torch allocator + self.offload_before_refit() # rerun the old offload function + + # Print memory stats after offloading + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + def move_optimizer_to_device(self, device: str | torch.device) -> None: + for state in self.optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, (DTensor, torch.Tensor)): + state[k] = v.to(device) + + def move_to_device(self, model: nn.Module, device: str | torch.device) -> nn.Module: + model = self.move_buffer_to_device(model, device) + return model.to(device) + + def move_buffer_to_device( + self, model: nn.Module, device: str | torch.device + ) -> nn.Module: + # FSDP modules do not move buffers to the device automatically + for v in model.buffers(): + torch.utils.swap_tensors(v, v.to(device)) + + return model + + def move_to_cuda(self, model: torch.nn.Module) -> torch.nn.Module: + model = self.move_to_device(model, "cuda") + gc.collect() + torch.cuda.empty_cache() + return model + + def move_to_cpu(self, model: torch.nn.Module) -> torch.nn.Module: + model = self.move_to_device(model, "cpu") + 
gc.collect() + torch.cuda.empty_cache() + return model + + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, + ) -> None: + """Save a checkpoint of the model. + + the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + """ + self.checkpoint_manager.save_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer, + optimizer_path=optimizer_path, + scheduler=self.scheduler, + tokenizer=self.tokenizer if tokenizer_path is None else None, + tokenizer_path=tokenizer_path, + checkpointing_cfg=checkpointing_cfg, + lora_enabled=self.lora_enabled, + peft_config=self.peft_config, + ) + + def load_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + ) -> None: + """Load a checkpoint into the model using Automodel Checkpointer.""" + self.checkpoint_manager.load_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer, + optimizer_path=optimizer_path, + scheduler=self.scheduler, + ) + + def _init_checkpoint_manager( + self, + config_updates: Optional[dict[str, Any]] = None, + checkpoint_root: Optional[str] = None, + ) -> None: + """Initialize the AutomodelCheckpointManager for this worker. + + This creates the checkpoint manager bound to this worker's device meshes + and initializes its underlying checkpointer. + + Args: + config_updates: Dict of CheckpointingConfig fields to set during initialization. + checkpoint_root: Optional root directory for checkpoints. 
+ """ + if self.checkpoint_manager is None: + self.checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=self.dp_mesh, + tp_mesh=self.tp_mesh, + model_state_dict_keys=getattr(self, "model_state_dict_keys", None), + moe_mesh=self.moe_mesh, + ) + self.checkpoint_manager.init_checkpointer( + config_updates=config_updates, + checkpoint_root=checkpoint_root, + ) diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index 22a406dded..e09ecf2972 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -24,7 +24,12 @@ import shutil import warnings from pathlib import Path -from typing import Any, Mapping, NotRequired, Optional, TypedDict, Union +import sys +if sys.version_info >= (3, 11): + from typing import Any, Mapping, NotRequired, Optional, TypedDict, Union +else: + from typing import Any, Mapping, Optional, TypedDict, Union + from typing_extensions import NotRequired import numpy as np import torch diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index 1efe09e6d9..fc6dd307d6 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -22,7 +22,12 @@ import threading import time from abc import ABC, abstractmethod -from typing import Any, Callable, Mapping, NotRequired, Optional, TypedDict +import sys +if sys.version_info >= (3, 11): + from typing import Any, Callable, Mapping, NotRequired, Optional, TypedDict +else: + from typing import Any, Callable, Mapping, Optional, TypedDict + from typing_extensions import NotRequired import mlflow import numpy as np diff --git a/pyproject.toml b/pyproject.toml index 462d83c8a8..5544eb9ce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "nvtx", "matplotlib", "plotly", + "numba", "sympy>=1.14.0", "pillow>=12.1.1", "torchvision==0.25.0", diff --git a/x_token/cuda_tokenalign_dp.py b/x_token/cuda_tokenalign_dp.py new file mode 100644 index 0000000000..1c67eaed37 --- /dev/null +++ b/x_token/cuda_tokenalign_dp.py @@ -0,0 +1,232 @@ 
+""" +CUDA DP implementation that matches TokenAligner's DP transition rules. + +This module mirrors the move set used by: +`TokenAligner.align_tokens_with_combinations_numpy_jit`: + - diag (match/mismatch) + - up / left (gap) + - comb_s1_over_s2_k (10 + k) + - comb_s2_over_s1_k (20 + k) + +It also mirrors the current chunked recursion strategy: +split at midpoint when sequence length exceeds `chunk_size`. +""" + +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.cpp_extension import load + + +_INVALID = np.int64(-1) + + +@lru_cache(maxsize=1) +def _load_cuda_ext(): + src = Path(__file__).with_name("cuda_tokenalign_dp_kernel.cu") + return load( + name="cuda_tokenalign_dp_ext", + sources=[str(src)], + extra_cuda_cflags=["-O3", "--use_fast_math"], + verbose=False, + ) + + +def _build_ids_and_joined_tables( + seq1: List[str], + seq2: List[str], + max_comb_len: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + token_to_id: dict[str, int] = {} + next_id = 0 + + def get_id(s: str) -> int: + nonlocal next_id + maybe = token_to_id.get(s) + if maybe is None: + maybe = next_id + token_to_id[s] = maybe + next_id += 1 + return maybe + + n1 = len(seq1) + n2 = len(seq2) + ids1 = np.array([get_id(t) for t in seq1], dtype=np.int64) + ids2 = np.array([get_id(t) for t in seq2], dtype=np.int64) + + joined1 = np.full((n1 + 1, max_comb_len + 1), _INVALID, dtype=np.int64) + for i in range(n1 + 1): + for k in range(2, min(i, max_comb_len) + 1): + joined1[i, k] = get_id("".join(seq1[i - k : i])) + + joined2 = np.full((n2 + 1, max_comb_len + 1), _INVALID, dtype=np.int64) + for j in range(n2 + 1): + for k in range(2, min(j, max_comb_len) + 1): + joined2[j, k] = get_id("".join(seq2[j - k : j])) + + return ids1, ids2, joined1, joined2 + + +def _backtrack_from_trace( + seq1: List[str], + seq2: List[str], + trace_np: np.ndarray, +) -> 
List[Tuple[List[str], List[str], int, int, int, int]]: + n1 = len(seq1) + n2 = len(seq2) + aligned: List[Tuple[List[str], List[str], int, int, int, int]] = [] + + i, j = n1, n2 + while i > 0 or j > 0: + move = int(trace_np[i, j]) + if move == 1: + aligned.append(([seq1[i - 1]], [seq2[j - 1]], i - 1, i, j - 1, j)) + i -= 1 + j -= 1 + elif move == 2: + aligned.append(([seq1[i - 1]], [], i - 1, i, -1, -1)) + i -= 1 + elif move == 3: + aligned.append(([], [seq2[j - 1]], -1, -1, j - 1, j)) + j -= 1 + elif 10 <= move < 20: + k = move - 10 + aligned.append(([seq1[i - 1]], seq2[j - k : j], i - 1, i, j - k, j)) + i -= 1 + j -= k + elif 20 <= move < 30: + k = move - 20 + aligned.append((seq1[i - k : i], [seq2[j - 1]], i - k, i, j - 1, j)) + i -= k + j -= 1 + else: + break + + aligned.reverse() + return aligned + + +def align_tokens_with_combinations_cuda_chunked( + seq1: List[str], + seq2: List[str], + exact_match_score: float = 3.0, + combination_score_multiplier: float = 1.5, + gap_penalty: float = -1.5, + max_combination_len: int = 4, + chunk_size: int = 128, +) -> tuple[List[Tuple[List[str], List[str], int, int, int, int]], float]: + """ + CUDA version of TokenAligner's chunked DP. + + Notes: + - Mirrors midpoint recursion in current python implementation. + - Uses CUDA only at the base chunk DP solve. 
+ """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for align_tokens_with_combinations_cuda_chunked") + + n1, n2 = len(seq1), len(seq2) + if n1 <= chunk_size and n2 <= chunk_size: + ids1, ids2, joined1, joined2 = _build_ids_and_joined_tables(seq1, seq2, max_combination_len) + device = torch.device("cuda") + ids1_t = torch.from_numpy(ids1).to(device=device, dtype=torch.int64, non_blocking=True) + ids2_t = torch.from_numpy(ids2).to(device=device, dtype=torch.int64, non_blocking=True) + j1_t = torch.from_numpy(joined1).to(device=device, dtype=torch.int64, non_blocking=True) + j2_t = torch.from_numpy(joined2).to(device=device, dtype=torch.int64, non_blocking=True) + + ext = _load_cuda_ext() + trace_t, score_t = ext.dp_chunk_cuda( + ids1_t.contiguous(), + ids2_t.contiguous(), + j1_t.contiguous(), + j2_t.contiguous(), + float(exact_match_score), + float(combination_score_multiplier), + float(gap_penalty), + int(max_combination_len), + ) + trace_np = trace_t.cpu().numpy() + aligned = _backtrack_from_trace(seq1, seq2, trace_np) + score = float(score_t.item()) + return aligned, score + + # Mirrors existing midpoint split strategy. 
+ mid1, mid2 = n1 // 2, n2 // 2 + left_aligned, left_score = align_tokens_with_combinations_cuda_chunked( + seq1[:mid1], + seq2[:mid2], + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + chunk_size=chunk_size, + ) + right_aligned, right_score = align_tokens_with_combinations_cuda_chunked( + seq1[mid1:], + seq2[mid2:], + exact_match_score=exact_match_score, + combination_score_multiplier=combination_score_multiplier, + gap_penalty=gap_penalty, + max_combination_len=max_combination_len, + chunk_size=chunk_size, + ) + + adjusted_right = [] + for s1_tokens, s2_tokens, s1_start, s1_end, s2_start, s2_end in right_aligned: + new_s1_start = s1_start + mid1 if s1_start >= 0 else -1 + new_s1_end = s1_end + mid1 if s1_end >= 0 else -1 + new_s2_start = s2_start + mid2 if s2_start >= 0 else -1 + new_s2_end = s2_end + mid2 if s2_end >= 0 else -1 + adjusted_right.append((s1_tokens, s2_tokens, new_s1_start, new_s1_end, new_s2_start, new_s2_end)) + + return left_aligned + adjusted_right, (left_score + right_score) + + +def monkeypatch_tokenaligner_cuda_basecase() -> None: + """ + Monkeypatch TokenAligner base chunk DP with CUDA kernel-backed version. 
+
+    Usage:
+        from cuda_tokenalign_dp import monkeypatch_tokenaligner_cuda_basecase
+        monkeypatch_tokenaligner_cuda_basecase()
+    """
+    from nemo_rl.algorithms.x_token.tokenalign import TokenAligner
+
+    # Capture the original implementation before patching: referencing
+    # TokenAligner.align_tokens_combinations_chunked inside _cuda_chunked after
+    # the patch is applied would recurse into the patched function itself.
+    _original_chunked = TokenAligner.align_tokens_combinations_chunked
+
+    def _cuda_chunked(
+        seq1: List[str],
+        seq2: List[str],
+        exact_match_score: float = 3.0,
+        combination_score_multiplier: float = 1.5,
+        gap_penalty: float = -1.5,
+        max_combination_len: int = 4,
+        ignore_leading_char_diff: bool = False,
+        chunk_size: int = 256,
+    ):
+        if ignore_leading_char_diff:
+            # The CUDA base case does not support ignore_leading_char_diff;
+            # fall back to the original Python implementation.
+            return _original_chunked(
+                seq1=seq1,
+                seq2=seq2,
+                exact_match_score=exact_match_score,
+                combination_score_multiplier=combination_score_multiplier,
+                gap_penalty=gap_penalty,
+                max_combination_len=max_combination_len,
+                ignore_leading_char_diff=ignore_leading_char_diff,
+                chunk_size=chunk_size,
+            )
+        return align_tokens_with_combinations_cuda_chunked(
+            seq1=seq1,
+            seq2=seq2,
+            exact_match_score=exact_match_score,
+            combination_score_multiplier=combination_score_multiplier,
+            gap_penalty=gap_penalty,
+            max_combination_len=max_combination_len,
+            chunk_size=chunk_size,
+        )
+
+    TokenAligner.align_tokens_combinations_chunked = staticmethod(_cuda_chunked)
diff --git a/x_token/cuda_tokenalign_dp_kernel.cu b/x_token/cuda_tokenalign_dp_kernel.cu
new file mode 100644
index 0000000000..5d66433bb5
--- /dev/null
+++ b/x_token/cuda_tokenalign_dp_kernel.cu
@@ -0,0 +1,166 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
+#include <cstdint>
+#include <vector>
+
+namespace {
+
+__global__ void dp_chunk_kernel(
+    const int64_t* __restrict__ ids1,
+    const int64_t* __restrict__ ids2,
+    const int64_t* __restrict__ joined1,
+    const int64_t* __restrict__ joined2,
+    float* __restrict__ dp,
+    int16_t* __restrict__ trace,
+    int n1,
+    int n2,
+    int max_comb_len,
+    float exact_match_score,
+    float gap_penalty,
+    float combination_score_multiplier
+) {
+    const int tid = threadIdx.x;
+    const int dp_cols = n2 + 1;
+    const int join_cols = max_comb_len + 1;
+    const int64_t invalid = static_cast<int64_t>(-1);
+    for (int i = tid; i <= n1; i += blockDim.x) {
+        dp[i * dp_cols + 0] = static_cast<float>(i) * gap_penalty;
+        trace[i * dp_cols + 0] = 2;  // up
+    }
+    for (int j = tid; j <= n2; j += blockDim.x) {
+        dp[0 * dp_cols + j] = static_cast<float>(j) * gap_penalty;
+        trace[0 * dp_cols + j] = 3;  // left
+    }
+    if (tid == 0) {
+        trace[0] = 0;
+    }
+    __syncthreads();
+
+    for (int diag = 2; diag <= n1 + n2; ++diag) {
+        const int j_start = max(1, diag - n1);
+        const int j_end = min(n2, diag - 1);
+        const int cells = j_end - j_start + 1;
+
+        for (int t = tid; t < cells; t += blockDim.x) {
+            const int j = j_start + t;
+            const int i = diag - j;
+
+            const int64_t id_i = ids1[i - 1];
+            const int64_t id_j = ids2[j - 1];
+
+            float best = dp[(i - 1) * dp_cols + (j - 1)] + ((id_i == id_j) ? exact_match_score : -exact_match_score);
+            int16_t best_move = 1;  // diag
+
+            float s_up = dp[(i - 1) * dp_cols + j] + gap_penalty;
+            if (s_up > best) {
+                best = s_up;
+                best_move = 2;
+            }
+
+            float s_left = dp[i * dp_cols + (j - 1)] + gap_penalty;
+            if (s_left > best) {
+                best = s_left;
+                best_move = 3;
+            }
+
+            const int k_max_s2 = min(j, max_comb_len);
+            for (int k = 2; k <= k_max_s2; ++k) {
+                const int64_t joined = joined2[j * join_cols + k];
+                if (joined != invalid && id_i == joined) {
+                    float s = dp[(i - 1) * dp_cols + (j - k)] + combination_score_multiplier * static_cast<float>(k);
+                    if (s > best) {
+                        best = s;
+                        best_move = static_cast<int16_t>(10 + k);  // comb_s1_over_s2_k
+                    }
+                }
+            }
+
+            const int k_max_s1 = min(i, max_comb_len);
+            for (int k = 2; k <= k_max_s1; ++k) {
+                const int64_t joined = joined1[i * join_cols + k];
+                if (joined != invalid && id_j == joined) {
+                    float s = dp[(i - k) * dp_cols + (j - 1)] + combination_score_multiplier * static_cast<float>(k);
+                    if (s > best) {
+                        best = s;
+                        best_move = static_cast<int16_t>(20 + k);  // comb_s2_over_s1_k
+                    }
+                }
+            }
+
+            dp[i * dp_cols + j] = best;
+            trace[i * dp_cols + j] = best_move;
+        }
+        __syncthreads();
+    }
+}
+
+inline void check_cuda(const torch::Tensor& t, const char* name) {
+    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor");
+    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
+}
+
+}  // namespace
+
+std::vector<torch::Tensor> dp_chunk_cuda(
+    torch::Tensor ids1,
+    torch::Tensor ids2,
+    torch::Tensor joined1,
+    torch::Tensor joined2,
+    double exact_match_score,
+    double combination_score_multiplier,
+    double gap_penalty,
+    int64_t max_comb_len
+) {
+    check_cuda(ids1, "ids1");
+    check_cuda(ids2, "ids2");
+    check_cuda(joined1, "joined1");
+    check_cuda(joined2, "joined2");
+
+    TORCH_CHECK(ids1.dtype() == torch::kInt64, "ids1 must be int64");
+    TORCH_CHECK(ids2.dtype() == torch::kInt64, "ids2 must be int64");
+    TORCH_CHECK(joined1.dtype() == torch::kInt64, "joined1 must be int64");
+    TORCH_CHECK(joined2.dtype() == torch::kInt64, "joined2 must be int64");
+
+    TORCH_CHECK(ids1.dim() == 1, "ids1 must be [n1]");
+    TORCH_CHECK(ids2.dim() == 1, "ids2 must be [n2]");
+    TORCH_CHECK(joined1.dim() == 2, "joined1 must be [n1+1, max_comb_len+1]");
+    TORCH_CHECK(joined2.dim() == 2, "joined2 must be [n2+1, max_comb_len+1]");
+
+    const int n1 = static_cast<int>(ids1.size(0));
+    const int n2 = static_cast<int>(ids2.size(0));
+    TORCH_CHECK(joined1.size(0) == n1 + 1, "joined1 first dim must be n1+1");
+    TORCH_CHECK(joined2.size(0) == n2 + 1, "joined2 first dim must be n2+1");
+    TORCH_CHECK(joined1.size(1) == max_comb_len + 1, "joined1 second dim mismatch");
+    TORCH_CHECK(joined2.size(1) == max_comb_len + 1, "joined2 second dim mismatch");
+
+    auto dp = torch::empty({n1 + 1, n2 + 1}, ids1.options().dtype(torch::kFloat32));
+    auto trace = torch::empty({n1 + 1, n2 + 1}, ids1.options().dtype(torch::kInt16));
+
+    const int threads = 256;
+    dp_chunk_kernel<<<1, threads, 0, at::cuda::getDefaultCUDAStream()>>>(
+        ids1.data_ptr<int64_t>(),
+        ids2.data_ptr<int64_t>(),
+        joined1.data_ptr<int64_t>(),
+        joined2.data_ptr<int64_t>(),
+        dp.data_ptr<float>(),
+        trace.data_ptr<int16_t>(),
+        n1,
+        n2,
+        static_cast<int>(max_comb_len),
+        static_cast<float>(exact_match_score),
+        static_cast<float>(gap_penalty),
+        static_cast<float>(combination_score_multiplier)
+    );
+
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    auto score = dp.index({n1, n2}).unsqueeze(0);
+    return {trace, score};
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("dp_chunk_cuda", &dp_chunk_cuda, "TokenAlign DP chunk CUDA");
+}
diff --git a/x_token/sanity_check_alignment_and_loss.py b/x_token/sanity_check_alignment_and_loss.py
new file mode 100644
index 0000000000..3c5c6600b7
--- /dev/null
+++ b/x_token/sanity_check_alignment_and_loss.py
@@ -0,0 +1,496 @@
+#!/usr/bin/env python3
+"""Sanity check: compare alignment + loss between the standalone TokenAligner
+pipeline (train_distillation_ddp.py) and the NeMo RL pipeline
+(off_policy_distillation.py + loss_functions.py).
+
+Uses the actual code from both pipelines rather than re-implementing them:
+  - Path A imports TokenizeAndAlignCollator from tokenalign/src/pytorch_data_loader.py
+    and calls TokenAligner.compute_loss() (same as train_distillation_ddp.py)
+  - Path B replicates off_policy_distillation.py's decode-reencode + align flow
+    and calls CrossTokenizerDistillationLossFn (same as NeMo RL training)
+
+Usage:
+    python x_token/sanity_check_alignment_and_loss.py \
+        --projection-matrix-path <path> \
+        --student-model meta-llama/Llama-3.2-1B \
+        --teacher-model Qwen/Qwen3-8B-Base
+
+    # With gold loss:
+    python x_token/sanity_check_alignment_and_loss.py \
+        --projection-matrix-path <path> \
+        --student-model meta-llama/Llama-3.2-1B \
+        --teacher-model Qwen/Qwen3-8B-Base \
+        --gold-loss
+"""
+import argparse
+import os
+import sys
+from unittest.mock import MagicMock
+
+# Stub out heavy dependencies that NeMo RL imports but we don't need.
+# Uses a meta-path finder so *any* import under these prefixes is intercepted
+# before Python's normal import machinery tries to find them on disk.
+import types
+import importlib
+import importlib.abc
+import importlib.machinery
+
+_STUB_PREFIXES = (
+    "ray", "vllm", "uvicorn", "tensorstore", "zarr", "torchdata",
+    "fastapi", "starlette", "pydantic_settings", "sse_starlette",
+    "mlflow", "wandb", "tensorboard",
+    "nemo_rl.models.generation.vllm",
+    "nemo_rl.models.policy.lm_policy",
+    "nemo_rl.models.policy.hf_policy",
+    "nemo_rl.distributed.virtual_cluster",
+)
+
+class _StubFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
+    # Use the PEP 451 find_spec/create_module/exec_module protocol; the legacy
+    # find_module/load_module pair is deprecated and removed in Python 3.12.
+    def find_spec(self, fullname, path=None, target=None):
+        if any(fullname == p or fullname.startswith(p + ".") for p in _STUB_PREFIXES):
+            return importlib.machinery.ModuleSpec(fullname, self, is_package=True)
+        return None
+
+    def create_module(self, spec):
+        return _StubModule(spec.name)
+
+    def exec_module(self, module):
+        pass
+
+class _StubModule(types.ModuleType):
+    """Module stub that returns a MagicMock for any attribute access."""
+    def __init__(self, name):
+        super().__init__(name)
+        self.__spec__ = importlib.machinery.ModuleSpec(name, None)
+        self.__path__ = []
+        self.__file__ = f"<stub {name}>"
+        self.__package__ = name
+        self.__loader__ = None
+    def __getattr__(self, name):
+        if name.startswith("__") and name.endswith("__"):
+            raise AttributeError(name)
+        return MagicMock()
+
+sys.meta_path.insert(0, _StubFinder())
+
+import torch
+
+# Make both codebases importable
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+TOKENALIGN_ROOT = os.path.join(os.path.dirname(__file__), "..", "..", "tokenalign")
+sys.path.insert(0, TOKENALIGN_ROOT)
+sys.path.insert(0, os.path.join(TOKENALIGN_ROOT, "src"))
+
+
+def print_header(title: str):
+    print(f"\n{'=' * 60}")
+    print(f" {title}")
+    print(f"{'=' * 60}")
+
+
+def compare_aligned_pairs(pairs_a, pairs_b, label_a="Path A", label_b="Path B", max_show=10):
+    """Print a summary comparison of two sets of aligned pairs."""
+    for batch_idx in range(len(pairs_a)):
+        n_a = len(pairs_a[batch_idx])
+        n_b = len(pairs_b[batch_idx])
+        print(f"\n  Batch {batch_idx}:
{label_a}={n_a} pairs, {label_b}={n_b} pairs") + + if n_a == n_b: + diffs = 0 + for i, (pa, pb) in enumerate(zip(pairs_a[batch_idx], pairs_b[batch_idx])): + if pa[:6] != pb[:6]: + diffs += 1 + if diffs <= max_show: + print(f" Pair {i} differs:") + print(f" {label_a}: {pa[:6]}") + print(f" {label_b}: {pb[:6]}") + if diffs == 0: + print(f" All {n_a} pairs are identical (first 6 fields)") + elif diffs > max_show: + print(f" ... and {diffs - max_show} more differences") + else: + print(f" Pair counts differ -- showing first {max_show} from each:") + for i, p in enumerate(pairs_a[batch_idx][:max_show]): + print(f" {label_a}[{i}]: {p[:6]}") + for i, p in enumerate(pairs_b[batch_idx][:max_show]): + print(f" {label_b}[{i}]: {p[:6]}") + + +def main(): + parser = argparse.ArgumentParser(description="Cross-tokenizer alignment + loss sanity check") + parser.add_argument("--projection-matrix-path", type=str, required=True) + parser.add_argument("--student-model", type=str, default="meta-llama/Llama-3.2-1B") + parser.add_argument("--teacher-model", type=str, default="Qwen/Qwen3-8B-Base") + parser.add_argument("--use-sparse", action="store_true", default=False) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--vocab-topk", type=int, default=8192) + parser.add_argument("--exact-match-only", action="store_true", default=False) + parser.add_argument("--reverse-kl", action="store_true", default=False) + parser.add_argument("--gold-loss", action="store_true", default=False) + parser.add_argument("--xtoken-loss", action="store_true", default=False) + parser.add_argument("--text", type=str, default=None, + help="Sample text to use. If not provided, a default is used.") + parser.add_argument("--max-seq-len", type=int, default=128, + help="Max sequence length for tokenization (ctx_length)") + parser.add_argument("--debug-dir", type=str, default=None, + help="Path to debug dump dir (from CrossTokenKL DEBUG). 
" + "If provided, uses saved logits instead of random.") + parser.add_argument("--rank", type=int, default=0, help="Rank for debug dump file") + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + DEFAULT_TEXT = ( + "The quick brown fox jumps over the lazy dog. " + "Artificial intelligence is transforming how we build software. " + "Large language models learn patterns from vast amounts of text data." + ) + text = args.text or DEFAULT_TEXT + + # ================================================================ + # Setup: Tokenizers + TokenAligner (shared by both paths) + # ================================================================ + print_header("SETUP") + from transformers import AutoTokenizer + from nemo_rl.algorithms.x_token import TokenAligner + + student_tokenizer = AutoTokenizer.from_pretrained(args.student_model) + teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher_model) + if student_tokenizer.pad_token is None: + student_tokenizer.pad_token = student_tokenizer.eos_token + if teacher_tokenizer.pad_token is None: + teacher_tokenizer.pad_token = teacher_tokenizer.eos_token + + aligner = TokenAligner( + teacher_tokenizer_name=args.teacher_model, + student_tokenizer_name=args.student_model, + init_hf_tokenizers=True, + ) + aligner._load_logits_projection_map( + file_path=args.projection_matrix_path, + use_sparse_format=args.use_sparse, + device=device, + ) + aligner = aligner.to(device) + + print(f" Student model: {args.student_model}") + print(f" Teacher model: {args.teacher_model}") + print(f" Projection: {args.projection_matrix_path}") + print(f" Device: {device}") + print(f" Temperature: {args.temperature}") + print(f" Vocab top-k: {args.vocab_topk}") + print(f" Gold loss: {args.gold_loss}") + print(f" XToken loss: {args.xtoken_loss}") + print(f" Reverse KL: {args.reverse_kl}") + + # ================================================================ + # Load debug dump if provided (real logits from 
training) + # ================================================================ + if args.debug_dir: + print_header("LOADING DEBUG DUMP") + debug_path = os.path.join(args.debug_dir, f"debug_rank{args.rank}.pt") + print(f" Loading from {debug_path}") + dump = torch.load(debug_path, map_location="cpu", weights_only=False) + student_logits = dump["student_logits"].to(device) + teacher_logits = dump["teacher_logits"].to(device) + print(f" student_logits: {student_logits.shape}") + print(f" teacher_logits: {teacher_logits.shape}") + else: + student_logits = None + teacher_logits = None + + # ================================================================ + # PATH A: train_distillation_ddp.py pipeline + # + # Uses TokenizeAndAlignCollator from tokenalign/src/pytorch_data_loader.py + # which is what TorchDataLoaderXToken uses internally. + # Then calls TokenAligner.compute_loss() for the loss. + # ================================================================ + print_header("PATH A: train_distillation_ddp pipeline") + print(" (TokenizeAndAlignCollator -> TokenAligner.compute_loss)") + + from pytorch_data_loader import TokenizeAndAlignCollator + + # Build the collator exactly as TorchDataLoaderXToken does (line 290-301) + collator_a = TokenizeAndAlignCollator( + tokenizer_student=student_tokenizer, + tokenizer_teacher=teacher_tokenizer, + token_aligner=aligner, + ctx_length=args.max_seq_len, + chunk_size=64, + same_vocab=False, + characters_per_sample=None, + align_convert_to_tokens=True, + text_key="text", + ) + + # Feed the text as a batch of one sample (same as dataloader would) + input_ids_student_a, input_ids_teacher_a, aligned_pairs_a = collator_a( + [{"text": text}] + ) + input_ids_student_a = input_ids_student_a.to(device) + input_ids_teacher_a = input_ids_teacher_a.to(device) + batch_size = input_ids_student_a.shape[0] + + print(f" Student tokens: {input_ids_student_a.shape}") + print(f" Teacher tokens: {input_ids_teacher_a.shape}") + print(f" Aligned pairs: 
{sum(len(ap) for ap in aligned_pairs_a)} total") + + # Generate or reuse logits + if student_logits is not None: + student_logits_a = student_logits + teacher_logits_a = teacher_logits + else: + from transformers import AutoConfig + s_vocab = AutoConfig.from_pretrained(args.student_model).vocab_size + t_vocab = AutoConfig.from_pretrained(args.teacher_model).vocab_size + torch.manual_seed(42) + student_logits_a = torch.randn( + batch_size, input_ids_student_a.shape[1], s_vocab, + device=device, dtype=torch.float32, + ) + teacher_logits_a = torch.randn( + batch_size, input_ids_teacher_a.shape[1], t_vocab, + device=device, dtype=torch.float32, + ) + + # Loss: TokenAligner.compute_loss (same call as train_distillation_ddp.py lines 1698-1713) + with torch.no_grad(): + loss_a, acc_a = aligner.compute_loss( + aligned_pairs=aligned_pairs_a, + student_logits=student_logits_a, + teacher_logits=teacher_logits_a, + input_ids_student=input_ids_student_a, + input_ids_teacher=input_ids_teacher_a, + loss_type="KL", + exact_token_match_only=args.exact_match_only, + temperature=args.temperature, + vocab_topk=args.vocab_topk, + reverse_kl=args.reverse_kl, + gold_loss=args.gold_loss, + xtoken_loss=args.xtoken_loss, + ) + + print(f"\n TokenAligner.compute_loss result:") + print(f" Loss: {loss_a.item():.6f}") + print(f" Top1 Acc: {acc_a:.4f}") + + # ================================================================ + # PATH B: off_policy_distillation.py pipeline + # + # Replicates the cross-tokenizer processing from + # off_policy_distillation.py lines 784-815: + # 1. Decode student tokens -> text + # 2. Re-encode with teacher tokenizer + # 3. Align + # 4. 
Call CrossTokenizerDistillationLossFn + # ================================================================ + print_header("PATH B: off_policy_distillation pipeline") + print(" (decode-reencode -> CrossTokenizerDistillationLossFn)") + + # --- Step 1-2: off_policy_distillation.py lines 785-806 --- + # Use student IDs from Path A as the starting point + # (in real training, these come from the NeMo RL data pipeline) + student_ids = input_ids_student_a + batch_size_ct = student_ids.shape[0] + + # off_policy_distillation.py line 788-791 + texts_b = [ + student_tokenizer.decode(student_ids[i].cpu().tolist(), skip_special_tokens=True) + for i in range(batch_size_ct) + ] + + # off_policy_distillation.py lines 797-806 + max_teacher_len = args.max_seq_len + teacher_encoded = teacher_tokenizer( + texts_b, + max_length=max_teacher_len, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + input_ids_teacher_b = teacher_encoded["input_ids"].to(device) + teacher_attention_mask = teacher_encoded["attention_mask"] + teacher_input_lengths_ct = teacher_attention_mask.sum(dim=1) + + # off_policy_distillation.py lines 808-810 + aligned_pairs_b = aligner.align( + student_ids, input_ids_teacher_b + ) + + input_ids_student_b = student_ids + + print(f" Student tokens: {input_ids_student_b.shape}") + print(f" Teacher tokens: {input_ids_teacher_b.shape}") + print(f" Aligned pairs: {sum(len(ap) for ap in aligned_pairs_b)} total") + + # Generate logits for Path B (must match teacher seq len) + if student_logits is not None: + student_logits_b = student_logits + teacher_logits_b = teacher_logits + else: + torch.manual_seed(42) + student_logits_b = torch.randn( + batch_size, input_ids_student_b.shape[1], s_vocab, + device=device, dtype=torch.float32, + ) + teacher_logits_b = torch.randn( + batch_size, input_ids_teacher_b.shape[1], t_vocab, + device=device, dtype=torch.float32, + ) + + # --- Step 3: Loss via CrossTokenizerDistillationLossFn --- + # off_policy_distillation.py 
lines 812-815 + 878-881 + from nemo_rl.algorithms.loss_functions import CrossTokenizerDistillationLossFn + + cfg = { + "loss_type": "KL", + "temperature": args.temperature, + "vocab_topk": args.vocab_topk, + "exact_token_match_only": args.exact_match_only, + "reverse_kl": args.reverse_kl, + "gold_loss": args.gold_loss, + "xtoken_loss": args.xtoken_loss, + } + loss_fn = CrossTokenizerDistillationLossFn(cfg, aligner) + loss_fn._debug_dumped = True # skip debug dump + + # off_policy_distillation.py lines 812-815 + loss_fn.set_cross_tokenizer_data( + teacher_input_ids=input_ids_teacher_b, + aligned_pairs=aligned_pairs_b, + ) + + student_seq_len_b = student_logits_b.shape[1] + nemo_data = { + "input_ids": input_ids_student_b, + "input_lengths": torch.tensor([student_seq_len_b] * batch_size, device=device), + "token_mask": torch.ones(batch_size, student_seq_len_b, device=device), + "sample_mask": torch.ones(batch_size, device=device), + } + global_valid_toks = torch.tensor(float(student_seq_len_b * batch_size), device=device) + global_valid_seqs = torch.tensor(float(batch_size), device=device) + + with torch.no_grad(): + loss_b_scaled, metrics_b = loss_fn( + next_token_logits=student_logits_b, + data=nemo_data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + teacher_logits=teacher_logits_b, + mb_idx=None, + mbs=None, + ) + + # Undo NeMo RL distributed scaling to get raw loss + local_valid = (student_seq_len_b - 1) * batch_size + raw_loss_b = ( + loss_b_scaled.item() * float(global_valid_toks) / local_valid + if local_valid > 0 else 0.0 + ) + + print(f"\n CrossTokenizerDistillationLossFn result:") + print(f" Loss (NeMo RL scaled): {loss_b_scaled.item():.6f}") + print(f" Loss (raw): {raw_loss_b:.6f}") + print(f" Metrics: {metrics_b}") + + # ================================================================ + # PATH B-REF: TokenAligner.compute_loss on Path B's alignment + # (isolates loss implementation difference from alignment difference) + # 
================================================================ + print_header("PATH B-REF: TokenAligner.compute_loss on Path B alignment") + + with torch.no_grad(): + loss_b_ref, acc_b_ref = aligner.compute_loss( + aligned_pairs=aligned_pairs_b, + student_logits=student_logits_b, + teacher_logits=teacher_logits_b, + input_ids_student=input_ids_student_b, + input_ids_teacher=input_ids_teacher_b, + loss_type="KL", + exact_token_match_only=args.exact_match_only, + temperature=args.temperature, + vocab_topk=args.vocab_topk, + reverse_kl=args.reverse_kl, + gold_loss=args.gold_loss, + xtoken_loss=args.xtoken_loss, + ) + print(f" Loss: {loss_b_ref.item():.6f}") + print(f" Top1 Acc: {acc_b_ref:.4f}") + + # ================================================================ + # ALIGNMENT COMPARISON + # ================================================================ + print_header("ALIGNMENT COMPARISON: Path A vs Path B") + + teacher_ids_match = torch.equal(input_ids_teacher_a, input_ids_teacher_b) + print(f" Teacher token IDs identical: {teacher_ids_match}") + if not teacher_ids_match: + for i in range(batch_size): + mask = input_ids_teacher_a[i] != input_ids_teacher_b[i] + n_diff = mask.sum().item() + total = input_ids_teacher_a.shape[1] + print(f" Batch {i}: {n_diff}/{total} tokens differ") + if n_diff > 0 and n_diff <= 20: + diff_positions = torch.where(mask)[0].tolist() + for pos in diff_positions[:10]: + a_tok = teacher_tokenizer.decode([input_ids_teacher_a[i, pos].item()]) + b_tok = teacher_tokenizer.decode([input_ids_teacher_b[i, pos].item()]) + print(f" pos {pos}: A='{a_tok}' ({input_ids_teacher_a[i, pos].item()}) " + f"vs B='{b_tok}' ({input_ids_teacher_b[i, pos].item()})") + + compare_aligned_pairs( + aligned_pairs_a, aligned_pairs_b, + label_a="Path A (train_distillation_ddp)", + label_b="Path B (off_policy_distillation)", + ) + + # ================================================================ + # LOSS COMPARISON + # 
================================================================ + print_header("LOSS COMPARISON") + + loss_a_val = loss_a.item() + loss_b_ref_val = loss_b_ref.item() + + print(f" Path A (TokenAligner.compute_loss, train_distillation_ddp pipeline):") + print(f" Loss = {loss_a_val:.6f}, Acc = {acc_a:.4f}") + print(f" Path B (CrossTokenizerDistillationLossFn, off_policy_distillation pipeline):") + print(f" Raw loss = {raw_loss_b:.6f}") + print(f" Path B-REF (TokenAligner.compute_loss, off_policy_distillation alignment):") + print(f" Loss = {loss_b_ref_val:.6f}, Acc = {acc_b_ref:.4f}") + + # --- Key comparison 1: alignment source --- + print(f"\n [1] Alignment difference (same loss fn, different tokenization):") + diff_alignment = abs(loss_a_val - loss_b_ref_val) + print(f" |Path A - Path B-REF| = {diff_alignment:.6f}") + if loss_a_val > 0: + print(f" Relative: {diff_alignment / loss_a_val * 100:.2f}%") + + # --- Key comparison 2: loss implementation --- + print(f"\n [2] Implementation difference (same alignment, different loss fn):") + diff_impl = abs(loss_b_ref_val - raw_loss_b) + print(f" |Path B-REF - Path B| = {diff_impl:.6f}") + if loss_b_ref_val > 0: + print(f" Relative: {diff_impl / loss_b_ref_val * 100:.2f}%") + + # --- Verdict --- + print(f"\n --- Verdict ---") + if diff_impl < 0.01: + print(f" LOSS MATCH -- CrossTokenizerDistillationLossFn matches TokenAligner.compute_loss") + elif diff_impl < 0.1: + print(f" ~ CLOSE -- small numerical differences between loss implementations") + else: + print(f" MISMATCH -- significant loss implementation difference, investigate") + + if diff_alignment > 0.01 and not teacher_ids_match: + print(f"\n NOTE: Alignment differs because off_policy_distillation decodes student") + print(f" tokens then re-encodes with teacher tokenizer, while train_distillation_ddp") + print(f" tokenizes the raw text independently. 
This is expected.") + elif diff_alignment < 0.001: + print(f"\n ALIGNMENT MATCH -- both pipelines produce same teacher tokens and alignment") + + +if __name__ == "__main__": + main() diff --git a/x_token/sanity_check_loss.py b/x_token/sanity_check_loss.py new file mode 100644 index 0000000000..3810895fde --- /dev/null +++ b/x_token/sanity_check_loss.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +"""Sanity check: compare loss from the actual CrossTokenizerDistillationLossFn +in loss_functions.py with the original TokenAligner.compute_loss(). + +Usage: + python x_token/sanity_check_loss.py \ + --debug-dir x_token/debug_dump \ + --projection-matrix-path cross_tokenizer_data/projection_map_Llama-3.2_to_Qwen3_multitoken_top_32_double_special.pt \ + --student-model meta-llama/Llama-3.2-1B \ + --teacher-model Qwen/Qwen3-8B-Base +""" +import argparse +import os +import sys + +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--debug-dir", type=str, default="/tmp/cross_tok_debug") + parser.add_argument("--rank", type=int, default=0) + parser.add_argument("--projection-matrix-path", type=str, required=True) + parser.add_argument("--student-model", type=str, default="meta-llama/Llama-3.2-1B") + parser.add_argument("--teacher-model", type=str, default="Qwen/Qwen3-8B-Base") + parser.add_argument("--use-sparse", action="store_true", default=False) + args = parser.parse_args() + + debug_path = os.path.join(args.debug_dir, f"debug_rank{args.rank}.pt") + print(f"Loading debug tensors from {debug_path}") + data = torch.load(debug_path, map_location="cpu", weights_only=False) + + student_logits = data["student_logits"] + teacher_logits = data["teacher_logits"] + input_ids_student = data["input_ids_student"] + input_ids_teacher = data["input_ids_teacher"] + aligned_pairs = data["aligned_pairs"] + cfg = data["config"] + + print(f"\nShapes:") + print(f" student_logits: 
{student_logits.shape}") + print(f" teacher_logits: {teacher_logits.shape}") + print(f" input_ids_student: {input_ids_student.shape}") + print(f" input_ids_teacher: {input_ids_teacher.shape}") + print(f" aligned_pairs: {len(aligned_pairs)} batches, " + f"{sum(len(ap) for ap in aligned_pairs)} total pairs") + print(f"\nConfig: {cfg}") + + device = "cuda" if torch.cuda.is_available() else "cpu" + student_logits = student_logits.to(device) + teacher_logits = teacher_logits.to(device) + input_ids_student = input_ids_student.to(device) + input_ids_teacher = input_ids_teacher.to(device) + + batch_size = student_logits.shape[0] + student_seq_len = student_logits.shape[1] + + # ---- Build the TokenAligner (shared by both paths) ---- + from nemo_rl.algorithms.x_token import TokenAligner + + aligner = TokenAligner( + teacher_tokenizer_name=args.teacher_model, + student_tokenizer_name=args.student_model, + init_hf_tokenizers=True, + ) + aligner._load_logits_projection_map( + file_path=args.projection_matrix_path, + use_sparse_format=args.use_sparse, + device=device, + ) + aligner = aligner.to(device) + + temperature = cfg.get("temperature", 1.0) + vocab_topk = cfg.get("vocab_topk", 8192) + exact_match = cfg.get("exact_token_match_only", False) + reverse_kl = cfg.get("reverse_kl", False) + + print(f"\n temperature={temperature}, vocab_topk={vocab_topk}, " + f"exact_match={exact_match}, reverse_kl={reverse_kl}") + + # ---- Path A: Original TokenAligner.compute_loss() ---- + print("\n" + "=" * 60) + print(" Path A: Original TokenAligner.compute_loss()") + print("=" * 60) + + with torch.no_grad(): + orig_loss, orig_acc = aligner.compute_loss( + aligned_pairs=aligned_pairs, + student_logits=student_logits, + teacher_logits=teacher_logits, + input_ids_student=input_ids_student, + input_ids_teacher=input_ids_teacher, + loss_type=cfg.get("loss_type", "KL"), + exact_token_match_only=exact_match, + temperature=temperature, + vocab_topk=vocab_topk, + reverse_kl=reverse_kl, + ) + 
print(f"\n Original loss: {orig_loss.item():.6f}") + print(f" Original topk_acc: {orig_acc:.4f}") + + # ---- Path B: Actual CrossTokenizerDistillationLossFn from loss_functions.py ---- + print("\n" + "=" * 60) + print(" Path B: CrossTokenizerDistillationLossFn (loss_functions.py)") + print("=" * 60) + + from nemo_rl.algorithms.loss_functions import CrossTokenizerDistillationLossFn + + loss_fn = CrossTokenizerDistillationLossFn(cfg, aligner) + loss_fn._debug_dumped = True # skip re-dumping + + loss_fn.set_cross_tokenizer_data( + teacher_input_ids=input_ids_teacher, + aligned_pairs=aligned_pairs, + ) + + # Build the NeMo RL data dict that __call__ expects + nemo_data = { + "input_ids": input_ids_student, + "input_lengths": torch.tensor([student_seq_len] * batch_size, device=device), + "token_mask": torch.ones(batch_size, student_seq_len, device=device), + "sample_mask": torch.ones(batch_size, device=device), + } + + # global_valid_toks/seqs: in real training these are summed across ranks. 
+ # For single-rank comparison, use the local counts here. Note the scaling is + # NOT exactly 1.0: local_valid_toks is (student_seq_len - 1) * batch_size due + # to the token_mask shift, while global_valid_toks uses student_seq_len, so + # the scaling is undone explicitly below to recover the raw chunk loss. + global_valid_toks = torch.tensor(float(student_seq_len * batch_size), device=device) + global_valid_seqs = torch.tensor(float(batch_size), device=device) + + with torch.no_grad(): + our_loss, our_metrics = loss_fn( + next_token_logits=student_logits, + data=nemo_data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + teacher_logits=teacher_logits, + mb_idx=None, + mbs=None, + ) + + # Undo the distributed scaling to get the raw chunk loss: + # loss_fn does: loss = raw_loss * local_valid_toks / global_valid_toks + # With token_mask=all-ones: local_valid_toks = (student_seq_len - 1) * batch_size + # (the -1 is from token_mask[:, 1:max_len+1]) + local_valid = (student_seq_len - 1) * batch_size + raw_our_loss = our_loss.item() * float(global_valid_toks) / local_valid if local_valid > 0 else 0.0 + + print(f"\n Our loss (after NeMo RL scaling): {our_loss.item():.6f}") + print(f" Our loss (raw, before scaling): {raw_our_loss:.6f}") + print(f" Metrics: {our_metrics}") + + # ---- Comparison ---- + print("\n" + "=" * 60) + print(" COMPARISON") + print("=" * 60) + print(f" Original TokenAligner loss (raw): {orig_loss.item():.6f}") + print(f" Our loss (raw, before scaling): {raw_our_loss:.6f}") + diff = abs(orig_loss.item() - raw_our_loss) + print(f" Absolute difference: {diff:.6f}") + if orig_loss.item() > 0: + print(f" Relative difference: {diff / orig_loss.item() * 100:.2f}%") + + if diff < 0.01: + print("\n MATCH -- losses are essentially identical") + elif diff < 0.1: + print("\n ~ CLOSE -- small numerical differences (likely from filtering)") + else: + print("\n MISMATCH -- significant difference, investigate further") + + +if __name__ == "__main__": + main()
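The scaling round-trip in `sanity_check_loss.py` above is easy to get wrong, so here is a standalone sketch (not part of the patch; the numbers are illustrative) of the arithmetic it relies on: the loss function scales by `local_valid_toks / global_valid_toks`, and the script recovers the raw chunk loss by multiplying by the inverse ratio.

```python
# Sketch of the scaling round-trip from sanity_check_loss.py.
# loss_fn does: loss = raw_loss * local_valid_toks / global_valid_toks,
# and the script undoes this by multiplying by the inverse ratio.
student_seq_len, batch_size = 128, 4
raw_loss = 2.345678  # hypothetical raw chunk loss

global_valid_toks = float(student_seq_len * batch_size)  # all positions
local_valid = (student_seq_len - 1) * batch_size         # -1 from token_mask[:, 1:max_len+1]

scaled = raw_loss * local_valid / global_valid_toks      # what loss_fn returns
recovered = scaled * global_valid_toks / local_valid     # what the script prints as "raw"

assert abs(recovered - raw_loss) < 1e-9
```

This also makes clear why `global_valid_toks = student_seq_len * batch_size` does not yield a scaling factor of exactly 1.0: the `local_valid` count loses one position per sequence to the next-token shift.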