Skip to content

Xtoken/off policy distillation gh#2245

Open
avenkateshha wants to merge 17 commits intoNVIDIA-NeMo:mainfrom
avenkateshha:xtoken/off-policy-distillation-gh
Open

Xtoken/off policy distillation gh#2245
avenkateshha wants to merge 17 commits intoNVIDIA-NeMo:mainfrom
avenkateshha:xtoken/off-policy-distillation-gh

Conversation

@avenkateshha
Copy link
Copy Markdown

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Adithyakrishna Hanasoge added 17 commits April 9, 2026 17:46
- Off-policy distillation pipeline (teacher Llama-3.1-8B, student Llama-3.2-1B)
  with arrow dataset support and inline MATH/MMLU generation-based evaluation
- Compute distributed log_softmax before top-k for correct KL divergence
- Add CUDA IPC buffer mechanism to avoid Ray object store bottleneck
  for large top-k logprob tensors (based on dtensor_sharath.py approach)
- Update loss function to skip re-normalization of teacher log probabilities
- Add submit scripts, configs, and eval benchmarks

Made-with: Cursor
Made-with: Cursor
- Flatten teacher IPC data structure and use mb_idx*mbs indexing instead
  of cumulative mb_offset/mb_size for microbatch slicing
- Add log_softmax for teacher top-k logits in standard (non-IPC) path
- Restore output_replicated for teacher path and add kl_loss/nll_loss
  to aggregated results
- Make arrow config self-contained with explicit settings

Made-with: Cursor
Made-with: Cursor
Refactor teacher logit sharing to use per-microbatch IPC buffers instead
of accumulating all teacher logits post-loop. Update loss functions to
handle optional microbatch indexing. Bump config to TP=4 and 10k steps.

Made-with: Cursor
- Add `use_ipc` config flag to switch between IPC (in-process
  communication) and non-IPC (data-dict) teacher logprob paths
- Simplify KL loss to use (k+1)-dim distributions with a "rest"
  bucket, unifying IPC and non-IPC code paths
- Branch teacher inference and student training for both train
  and validation loops based on use_ipc setting
- Update submit script with IPC test experiment config

Made-with: Cursor
- Add x_token/ module with TokenAligner for cross-vocabulary distillation
- Rewrite CrossTokenizerDistillationLossFn with chunk-averaged KL that
  handles 1:1, 1:many, many:1, and many:many token alignments, matching
  the original TokenAligner.compute_KL_loss_optimized() exactly
  (verified via sanity check: 0.00% difference)
- Fix teacher IPC to send full logits (topk_logits=None) instead of
  topk_logits=0 which produced empty tensors
- Pass global_top_indices through to projection for memory optimization
  (2.3GB -> 125MB for projection output tensor)
- Add cross-tokenizer data processing in training loop (dual tokenization,
  alignment, teacher data dict)
- Add unbuffered stdout for better SLURM log visibility
- Add example config and sanity check script

Made-with: Cursor
Documents architecture, all new/modified files, usage instructions,
configuration reference, and design decisions for the cross-tokenizer
off-policy distillation feature built on NeMo RL v0.5.0.

Made-with: Cursor
…compat

- Switch default teacher from Qwen3-8B to Phi-4-mini-instruct
- Add gold loss (common-vocab KL + uncommon-vocab L1) and xtoken loss modes
- Add CE loss with dynamic loss scaling option
- Replace dense projection with CSR sparse matmul for memory efficiency
- Add MMLU 5-shot evaluation benchmark
- Fix NotRequired import for Python <3.11 compatibility (16 files)
- Add submit_cross_tokenizer.sh sbatch script
- Add sanity check script for alignment and loss verification
- Update LR schedule for 80k step training (warmup 4k, cosine 76k)

Made-with: Cursor
…parse projection

- Cache CrossTokenizerDistillationLossFn on policy workers at init and
  pass None to train() calls, eliminating repeated Ray serialization of
  the loss function (which includes large sparse matrices) each step.
- Add set_loss_fn() and update_cross_tokenizer_data() to Policy and
  DTensorPolicyWorkerV2 to support per-step cross-tokenizer data updates.
- Optimize sparse token projection by pre-reducing the sparse matrix
  with index_select before projection instead of projecting full vocab
  and slicing afterward.
- Use AutoConfig.from_pretrained() for vocab sizes in sanity check script.

Made-with: Cursor
…P rank

Reduced training time with this optimization. Avoid Ray serialization of
the loss function by having each worker construct
CrossTokenizerDistillationLossFn from config + shared filesystem. Shard
teacher_input_ids and aligned_pairs per data-parallel rank instead of
broadcasting the full batch to every worker.

Made-with: Cursor
- Add O(n+m) character-offset alignment via two-pointer walk on tokenizer
  offset mappings, with automatic DP fallback for failed samples
- Precompute canonical token ID maps at startup to skip convert_ids_to_tokens
- Add Numba JIT-accelerated DP kernel and banded DP variant
- Add KD preprocessor preserving raw text for teacher tokenization
- Add numba dependency
- Update config: expand arrow data glob, set max_num_epochs=1
- Update submit script: bump max_num_steps=10, rename experiment to raw-text-kd-16node

Made-with: Cursor
Introduce CUDA kernel and Python integration module for faster TokenAligner dynamic programming base-case computation.

Made-with: Cursor
… preprocessing with current-step GPU training while keeping alignment behavior unchanged.

Add explicit token-aligner runtime switches (`use_char_offset`, `use_align_fast`, CUDA-DP toggles), clean up dead/duplicated paths, and simplify the step orchestration with typed prefetch payloads and helper extraction for maintainability.

Made-with: Cursor
Set explicit token aligner defaults and document the total-GPUs/2 heuristic for cross_tokenizer_num_workers so large-batch off-policy runs can iterate on stable, reproducible CT pool sizing.

Made-with: Cursor
@avenkateshha avenkateshha requested review from a team as code owners April 10, 2026 00:49
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot bot commented Apr 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants