diff --git a/.github/workflows/integration-tests-claude.yml b/.github/workflows/integration-tests-claude.yml new file mode 100644 index 0000000000..75f8e2adda --- /dev/null +++ b/.github/workflows/integration-tests-claude.yml @@ -0,0 +1,29 @@ +name: Claude Integration Tests +on: + schedule: + - cron: "0 6 * * 1" # Weekly Monday 6am UTC + workflow_dispatch: + push: + paths: + - "bionemo-recipes/claude-plugin/**" + - "bionemo-recipes/integration-tests/**" + +jobs: + test: + runs-on: linux-amd64-gpu-l4-latest-1 + container: + image: nvcr.io/nvidia/pytorch:25.06-py3 + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + steps: + - uses: actions/checkout@v4 + + - name: Install Claude Code CLI + run: npm install -g @anthropic-ai/claude-code + + - name: Install test dependencies + run: pip install pytest pytest-timeout + + - name: Run integration tests + run: cd bionemo-recipes/integration-tests && pytest -v --timeout=600 + timeout-minutes: 30 diff --git a/bionemo-recipes/claude-plugin/.claude-plugin/plugin.json b/bionemo-recipes/claude-plugin/.claude-plugin/plugin.json new file mode 100644 index 0000000000..6d8900dbfa --- /dev/null +++ b/bionemo-recipes/claude-plugin/.claude-plugin/plugin.json @@ -0,0 +1,9 @@ +{ + "name": "bionemo-recipes", + "version": "0.1.0", + "description": "Convert HuggingFace models to TransformerEngine, add FP8 support, set up distributed training — using NVIDIA BioNeMo Recipes as reference.", + "author": { "name": "NVIDIA BioNeMo Team" }, + "repository": "https://github.com/NVIDIA/bionemo-framework", + "license": "Apache-2.0", + "keywords": ["transformerengine", "fp8", "fsdp", "distributed-training", "nvidia"] +} diff --git a/bionemo-recipes/claude-plugin/README.md b/bionemo-recipes/claude-plugin/README.md new file mode 100644 index 0000000000..2db050ce3e --- /dev/null +++ b/bionemo-recipes/claude-plugin/README.md @@ -0,0 +1,73 @@ +# BioNeMo Recipes Claude Plugin + +A Claude Code plugin for converting HuggingFace models to NVIDIA 
TransformerEngine, +adding FP8/FP4 quantization support, writing golden value tests, and setting up +FSDP distributed training. All skills use real BioNeMo Recipes as reference implementations. + +## Installation + +```bash +claude --add-dir /path/to/bionemo-recipes/claude-plugin +``` + +## Available Skills + +| Skill | Description | +| ---------------------- | -------------------------------------------------------------------------------------------------------------------------- | +| `/te-convert-model` | Convert a HuggingFace `PreTrainedModel` to use TransformerEngine layers with bidirectional weight conversion (HF \<-> TE). | +| `/add-fp8-support` | Add FP8 or FP4 quantized training support to an existing TransformerEngine model. | +| `/write-golden-tests` | Generate golden value tests that verify a TE model produces identical outputs to the original HF reference model. | +| `/setup-fsdp-training` | Scaffold a complete FSDP training recipe with Hydra configs, distributed launcher, and Docker environment. | +| `/export-to-hf-hub` | Create an export script that bundles model weights, tokenizer, and config for publishing to the Hugging Face Hub. | + +## Usage Examples + +### Convert a HuggingFace model to TransformerEngine + +``` +/te-convert-model facebook/esm2_t33_650M_UR50D +``` + +Generates a TE-backed `PreTrainedModel` class with `convert_hf_to_te()` and +`convert_te_to_hf()` functions, following the pattern in `bionemo-recipes/models/`. + +### Add FP8 quantized training + +``` +/add-fp8-support --precision fp8 +``` + +Adds FP8 recipe configuration, `DelayedScaling` setup, and the `fp8_autocast` +context manager to your training loop. + +### Write golden value tests + +``` +/write-golden-tests --model esm2 --reference facebook/esm2_t33_650M_UR50D +``` + +Creates pytest tests that load both the HF reference and TE model, run a forward +pass with fixed inputs, and assert outputs match within tolerance. 
+ +### Set up FSDP distributed training + +``` +/setup-fsdp-training --model esm2 --framework native_te +``` + +Scaffolds a self-contained recipe directory with a Dockerfile, training script, +Hydra configs, and a sample data loader. + +### Export model to Hugging Face Hub + +``` +/export-to-hf-hub --model esm2 +``` + +Generates an `export.py` script that packages weights, config, and tokenizer +files for upload to Hugging Face Hub. + +## Links + +- [BioNeMo Framework](https://github.com/NVIDIA/bionemo-framework) +- [BioNeMo Recipes README](../README.md) diff --git a/bionemo-recipes/claude-plugin/skills/add-fp8-support/SKILL.md b/bionemo-recipes/claude-plugin/skills/add-fp8-support/SKILL.md new file mode 100644 index 0000000000..4df61cf27f --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/add-fp8-support/SKILL.md @@ -0,0 +1,136 @@ +--- +name: add-fp8-support +description: > + Add FP8, MXFP8, or NVFP4 quantization support to a TransformerEngine model. + Triggers when user asks about FP8, FP4, quantization, mixed precision, + or low-precision training. +allowed-tools: Read, Grep, Glob, Write, Edit, Bash, Agent +argument-hint: '[fp8|mxfp8|nvfp4]' +--- + +# Add FP8/FP4 Quantization Support + +You are adding quantization support to a TransformerEngine model. Read the reference files first. + +## Reference Files + +- `reference/quantization.py` — Layer-wise precision assignment +- `reference/fp8_config_example.py` — FP8 recipe setup in training + +## Steps + +### 1. Add Config Fields + +Add these fields to the NV config class: + +- `layer_precision: list[str | None] | None = None` — Per-layer precision ("fp8", "fp4", None) +- `use_quantized_model_init: bool = False` — Initialize weights directly in quantized format + +Validate in `__init__`: + +```python +if layer_precision is not None: + assert len(layer_precision) == self.num_hidden_layers + for p in layer_precision: + assert p in {"fp8", "fp4", None} +``` + +### 2. 
Pad Vocabulary Size + +FP8 requires tensor dimensions divisible by 16. Pad vocab: + +```python +self.padded_vocab_size = padded_vocab_size or self.vocab_size +# Round up to next multiple of 16 +if self.padded_vocab_size % 16 != 0: + self.padded_vocab_size = ((self.padded_vocab_size + 15) // 16) * 16 +``` + +Update embedding and LM head to use `padded_vocab_size`. Truncate logits back to `vocab_size` in forward pass. + +### 3. Implement `get_autocast_context()` + +This method returns the appropriate TE context manager for each layer: + +```python +from contextlib import nullcontext +import transformer_engine.pytorch as te + + +def get_autocast_context(self, layer_number, init=False, outer=False): + if self.config.layer_precision is None: + return nullcontext() + + # Outer context wraps entire encoder for recipe post-processing + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + return te.autocast(enabled=True, recipe=self._fp8_recipe) + + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + + # During init: use quantized_model_init for weight initialization + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return te.quantized_model_init(recipe=recipe) + return nullcontext() + + # During forward: use autocast for precision control + if precision in ("fp8", "fp4"): + return te.autocast(enabled=True, recipe=recipe) + return te.autocast(enabled=False) # Explicitly disable for BF16 layers +``` + +### 4. Use Contexts in Model + +During layer creation: + +```python +for i in range(config.num_hidden_layers): + with self.get_autocast_context(i, init=True): + layers.append(te.TransformerLayer(...)) +``` + +During forward pass: + +```python +with self.get_autocast_context(None, outer=True): + for layer_idx, layer in enumerate(self.layers): + with self.get_autocast_context(layer_idx): + hidden_states = layer(hidden_states, ...) 
+``` + +### 5. Keep LM Head in Higher Precision + +```python +with te.autocast(enabled=False): + logits = self.lm_head(hidden_states) +``` + +### 6. Set Up FP8 Recipes + +In training script: + +```python +from transformer_engine.common.recipe import DelayedScaling, Format + +fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID) +model = MyTEModel(config, fp8_recipe=fp8_recipe) +``` + +Available recipes: + +- `DelayedScaling` — Classic FP8, computes scaling factors with delay +- `Float8CurrentScaling` — Per-tensor current scaling +- `Float8BlockScaling` — Block-wise scaling (MXFP8) +- `NVFP4BlockScaling` — 4-bit quantization + +### 7. Layer-wise Precision Assignment + +Use `resolve_layer_precision()` from reference to assign layers: + +```python +# In config: fp8_layers=[1,2,3], fp4_layers=[4,5,6] (1-indexed) +# Returns: ["fp8","fp8","fp8","fp4","fp4","fp4"] (0-indexed) +``` diff --git a/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/fp8_config_example.py b/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/fp8_config_example.py new file mode 100644 index 0000000000..5794e49ca3 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/fp8_config_example.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Reference: FP8 recipe setup in a training script. + +Shows how to create and use FP8/FP4 recipes with TransformerEngine models. +""" + +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8BlockScaling, + Float8CurrentScaling, + Format, + NVFP4BlockScaling, +) + + +def create_fp8_recipe(recipe_name: str = "DelayedScaling", **kwargs): + """Create an FP8 recipe by name. + + Available recipes: + - DelayedScaling: Classic FP8, scaling factors computed with delay + - Float8CurrentScaling: Per-tensor scaling computed each step + - Float8BlockScaling: Block-wise scaling (MXFP8) + - NVFP4BlockScaling: 4-bit quantization + """ + recipes = { + "DelayedScaling": DelayedScaling, + "Float8CurrentScaling": Float8CurrentScaling, + "Float8BlockScaling": Float8BlockScaling, + "NVFP4BlockScaling": NVFP4BlockScaling, + } + recipe_cls = recipes[recipe_name] + + # NOTE: Format.HYBRID uses E4M3 for forward, E5M2 for backward + if "fp8_format" not in kwargs and recipe_name != "NVFP4BlockScaling": + kwargs["fp8_format"] = Format.HYBRID + if "fp4_format" not in kwargs and recipe_name == "NVFP4BlockScaling": + kwargs["fp4_format"] = Format.E2M1 + + return recipe_cls(**kwargs) + + +# Example usage in training script: +def setup_model_with_fp8(config, layer_precision): + """Example of setting up a TE model with FP8 quantization.""" + config.layer_precision = layer_precision + + fp8_recipe = create_fp8_recipe("DelayedScaling") + + # NOTE: Pass recipe to model constructor, not as global state + # model = NVModelForMaskedLM(config, fp8_recipe=fp8_recipe) + + return config, fp8_recipe diff --git a/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/quantization.py b/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/quantization.py new file mode 100644 index 0000000000..21f9a8df06 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/add-fp8-support/reference/quantization.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: Layer-wise quantization assignment utilities. + +Demonstrates how to resolve user-specified layer lists into per-layer precision assignments. +""" + + +def resolve_layer_precision( + num_layers: int, + fp8_enabled: bool, + fp4_enabled: bool, + fp8_layers: list[int] | None, + fp4_layers: list[int] | None, +) -> list[str | None]: + """Resolve layer-wise quantization from user config. + + Takes 1-indexed layer lists and returns 0-indexed precision list. + + Examples: + # All layers FP8 + resolve_layer_precision(6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None) + # -> ["fp8", "fp8", "fp8", "fp8", "fp8", "fp8"] + + # Mixed: layers 1-3 FP8, layers 4-6 FP4 + resolve_layer_precision(6, True, True, [1,2,3], [4,5,6]) + # -> ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"] + """ + all_layers = set(range(1, num_layers + 1)) + + if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None: + raise ValueError("Both fp8 and fp4 enabled but no layer lists specified. 
Provide explicit layer assignments.") + + # Auto-fill: if one format has explicit layers, other gets remaining + if fp8_enabled and fp8_layers is None: + claimed = set(fp4_layers) if fp4_layers else set() + fp8_layers = sorted(all_layers - claimed) + + if fp4_enabled and fp4_layers is None: + claimed = set(fp8_layers) if fp8_layers else set() + fp4_layers = sorted(all_layers - claimed) + + if not fp8_enabled: + fp8_layers = None + if not fp4_enabled: + fp4_layers = None + + # Validate no overlap + if fp8_layers and fp4_layers: + overlap = set(fp8_layers) & set(fp4_layers) + if overlap: + raise ValueError(f"Overlapping layers: {overlap}") + + fp8_set = set(fp8_layers) if fp8_layers else set() + fp4_set = set(fp4_layers) if fp4_layers else set() + return ["fp8" if i in fp8_set else "fp4" if i in fp4_set else None for i in range(1, num_layers + 1)] diff --git a/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/SKILL.md b/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/SKILL.md new file mode 100644 index 0000000000..2a92b337bd --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/SKILL.md @@ -0,0 +1,73 @@ +--- +name: export-to-hf-hub +description: > + Export a TransformerEngine model to HuggingFace Hub format. + Triggers when user asks to export, publish, upload to HuggingFace, + or create a model card. +allowed-tools: Read, Grep, Glob, Write, Edit, Bash, Agent +argument-hint: '[model-path] [hub-id]' +--- + +# Export TE Model to HuggingFace Hub + +You are creating an export pipeline that converts a HuggingFace model to TE format and packages it for distribution on HuggingFace Hub. + +## Reference Files + +- `reference/export_esm2.py` — Complete export script example + +## Steps + +### 1. Load and Convert + +```python +model_hf = AutoModelForMaskedLM.from_pretrained(model_id) +model_te = convert_hf_to_te(model_hf) +model_te.save_pretrained(export_path) +``` + +### 2. 
Save Tokenizer + +```python +tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) +tokenizer.save_pretrained(export_path) +``` + +### 3. Patch config.json with AUTO_MAP + +```python +import json + +with open(export_path / "config.json", "r") as f: + config = json.load(f) +config["auto_map"] = { + "AutoConfig": "model_file.NVConfig", + "AutoModel": "model_file.NVModel", + "AutoModelForMaskedLM": "model_file.NVModelForMaskedLM", +} +with open(export_path / "config.json", "w") as f: + json.dump(config, f, indent=2, sort_keys=True) +``` + +### 4. Copy Model Code as Remote Code + +```python +import shutil + +shutil.copy("modeling_te.py", export_path / "model_file.py") +``` + +### 5. Smoke Test + +```python +model = AutoModelForMaskedLM.from_pretrained(export_path, trust_remote_code=True) +``` + +### 6. Upload to Hub + +```python +from huggingface_hub import HfApi + +api = HfApi() +api.upload_folder(folder_path=export_path, repo_id="org/model-name") +``` diff --git a/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/reference/export_esm2.py b/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/reference/export_esm2.py new file mode 100644 index 0000000000..77166730b3 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/export-to-hf-hub/reference/export_esm2.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: ESM2 export script for HuggingFace Hub. + +Shows the complete export pipeline: +1. Load HF model -> convert to TE +2. Save model + tokenizer +3. Patch config with AUTO_MAP +4. Copy model code for trust_remote_code +5. Smoke test loading +""" + +import gc +import json +import shutil +from pathlib import Path + +import torch +from convert import convert_esm_hf_to_te +from modeling_esm_te import AUTO_MAP +from transformers import AutoModelForMaskedLM, AutoTokenizer + + +def export_hf_checkpoint(tag: str, export_path: Path): + """Export a HuggingFace checkpoint to TE format for Hub distribution. + + Args: + tag: HuggingFace model tag (e.g., "esm2_t6_8M_UR50D") + export_path: Directory to save exported model + """ + # NOTE: Load and convert + model_hf = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") + model_te = convert_esm_hf_to_te(model_hf) + model_te.save_pretrained(export_path / tag) + + # NOTE: Save tokenizer alongside model + tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") + tokenizer.save_pretrained(export_path / tag) + + # NOTE: Patch config.json with AUTO_MAP for trust_remote_code loading + with open(export_path / tag / "config.json", "r") as f: + config = json.load(f) + config["auto_map"] = AUTO_MAP + with open(export_path / tag / "config.json", "w") as f: + json.dump(config, f, indent=2, sort_keys=True) + + # NOTE: Copy modeling file as the remote code file + # The AUTO_MAP references "esm_nv.NVEsmForMaskedLM" so the file must be named esm_nv.py + shutil.copy("modeling_esm_te.py", export_path / tag / "esm_nv.py") + + # NOTE: Copy license + shutil.copy("LICENSE", export_path / tag / "LICENSE") + + # Clean up to free memory + del model_hf, model_te + gc.collect() + torch.cuda.empty_cache() + + # NOTE: Smoke test - verify the exported model loads correctly + model_te = AutoModelForMaskedLM.from_pretrained( + 
export_path / tag, + dtype=torch.bfloat16, + trust_remote_code=True, + ) + del model_te + gc.collect() + torch.cuda.empty_cache() diff --git a/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/SKILL.md b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/SKILL.md new file mode 100644 index 0000000000..bd6be98891 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/SKILL.md @@ -0,0 +1,97 @@ +--- +name: setup-fsdp-training +description: > + Set up FSDP2 or mFSDP distributed training for a TransformerEngine model. + Triggers when user asks about distributed training, FSDP, data parallel, + multi-GPU training, or scaling training. +allowed-tools: Read, Grep, Glob, Write, Edit, Bash, Agent +argument-hint: '[fsdp2|mfsdp]' +--- + +# Set Up FSDP2 Distributed Training + +You are setting up distributed training with PyTorch FSDP2 for a TransformerEngine model. Read the reference files first. + +## Reference Files + +- `reference/train_fsdp2.py` — Complete FSDP2 training script +- `reference/hydra_defaults.yaml` — Hydra configuration template + +## Steps + +### 1. Initialize Distributed + +```python +import torch +from torch.distributed.device_mesh import init_device_mesh + +torch.distributed.init_process_group(backend="nccl") +rank = int(os.environ["RANK"]) +local_rank = int(os.environ["LOCAL_RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +torch.cuda.set_device(local_rank) + +device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) +``` + +### 2. Create Model on Meta Device + +```python +with torch.device("meta"): + model = MyTEModel(config, fp8_recipe=fp8_recipe) +``` + +### 3. 
Apply FSDP Wrapping + +```python +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard + +# FP32 master weights with BF16 compute +mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, +) + +# Shard individual layers first, then the full model +for layer in model.layers: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) +fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) +``` + +### 4. Initialize Weights After Sharding + +```python +model.init_empty_weights() # Moves from meta to cuda +``` + +### 5. Create Optimizer (AFTER FSDP wrapping) + +```python +optimizer = torch.optim.AdamW( + model.parameters(), lr=4e-4, betas=(0.9, 0.98), weight_decay=0.01 +) +``` + +### 6. Training Loop + +```python +for step, batch in enumerate(dataloader): + batch = {k: v.to(device) for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + scheduler.step() + optimizer.zero_grad() +``` + +### 7. Distributed Checkpointing + +```python +import torch.distributed.checkpoint as dcp + +dcp.save({"model": model, "optimizer": optimizer}, checkpoint_id=ckpt_path) +dcp.load({"model": model, "optimizer": optimizer}, checkpoint_id=ckpt_path) +``` diff --git a/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/hydra_defaults.yaml b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/hydra_defaults.yaml new file mode 100644 index 0000000000..cc1e178b71 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/hydra_defaults.yaml @@ -0,0 +1,130 @@ +# Annotated Hydra defaults config for FSDP2 training. +# +# This is a generalized version of the bionemo-recipes training config. +# Fields marked with ??? are required and must be overridden. 
+# +# Source: bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml + +# --------------------------------------------------------------------------- +# Model config +# --------------------------------------------------------------------------- + +# Whether to use TransformerEngine layers (true) or plain HuggingFace model (false) +use_te: true + +# HuggingFace model ID or local path to config.json +# Examples: "nvidia/esm2_t6_8M_UR50D", "meta-llama/Llama-3.2-1B-Instruct" +config_name_or_path: ??? + +# Extra kwargs passed to Config.from_pretrained() +config_kwargs: {} + +# Total number of training steps +num_train_steps: ??? + +# --------------------------------------------------------------------------- +# Distributed / FSDP config +# --------------------------------------------------------------------------- + +# Use meta device for memory-efficient initialization (recommended for large models). +# Parameters are created as metadata only, then materialized after FSDP sharding. +use_meta_device: true + +# Whether to wrap model in torch.compile (may not be compatible with all features) +use_torch_compile: false + +# Context parallelism size. Set >1 for FSDP+CP with a 2D device mesh. +cp_size: 1 + +# --------------------------------------------------------------------------- +# Dataset config +# --------------------------------------------------------------------------- + +# Whether to use sequence packing (THD format) for better GPU utilization +use_sequence_packing: false + +dataset: + tokenizer_name: ??? + micro_batch_size: ??? + num_workers: 1 + max_seq_length: 1024 + # For MLM models: set mlm_probability (e.g., 0.15) + # For causal LM models: omit or set to null + mlm_probability: 0.15 + use_stateful_dataloader: false + load_dataset_kwargs: + path: ??? 
# HuggingFace dataset path + split: "train" + streaming: true # Streaming avoids downloading entire dataset + +# --------------------------------------------------------------------------- +# Logging config +# --------------------------------------------------------------------------- + +wandb_init_args: + name: ??? + +logger: + frequency: 100 + +# --------------------------------------------------------------------------- +# FP8 / FP4 quantization config +# --------------------------------------------------------------------------- + +# TransformerEngine FP8 config. +# See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html +fp8_config: + enabled: false + # Recipe class path. Options: + # transformer_engine.common.recipe.DelayedScaling (classic FP8) + # transformer_engine.common.recipe.Float8CurrentScaling (per-tensor) + # transformer_engine.common.recipe.Float8BlockScaling (MXFP8) + fp8_recipe: transformer_engine.common.recipe.DelayedScaling + # Format: "HYBRID" (E4M3 forward, E5M2 backward) or "E4M3" + fp8_format: "HYBRID" + fp8_recipe_kwargs: {} + +# NVFP4 quantization config +fp4_config: + enabled: false + fp4_recipe: transformer_engine.common.recipe.NVFP4BlockScaling + fp4_format: "E2M1" + fp4_recipe_kwargs: {} + +# Per-layer precision assignment (1-indexed layer lists). +# null means "all layers" when the corresponding config is enabled. +# Example: fp8_layers=[1,2,3] assigns layers 1-3 to FP8. +fp8_layers: null +fp4_layers: null + +# Whether to use FP32 master weights with BF16 compute. +# Recommended for training stability. 
+use_fp32_master_weights: null + +# --------------------------------------------------------------------------- +# Optimizer config +# --------------------------------------------------------------------------- + +adamw_kwargs: + lr: 4e-4 + fused: true # Use CUDA fused AdamW for better performance + betas: [0.9, 0.98] + eps: 1e-8 + weight_decay: 0.01 + +# Learning rate scheduler (linear warmup then linear decay) +lr_scheduler_kwargs: + num_warmup_steps: 2_000 + num_training_steps: 500_000 + +# --------------------------------------------------------------------------- +# Checkpoint config +# --------------------------------------------------------------------------- + +checkpoint: + ckpt_dir: ??? + save_final_model: true + resume_from_checkpoint: true + save_every_n_steps: 1_000 + max_checkpoints: 5 # Keep only the latest N checkpoints + async_save: true # Async save (currently only supported with FSDP2) diff --git a/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/train_fsdp2.py b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/train_fsdp2.py new file mode 100644 index 0000000000..76d12ad6d2 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/setup-fsdp-training/reference/train_fsdp2.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Reference: FSDP2 training script for TransformerEngine models. + +Trimmed and annotated version showing the essential distributed training patterns. +Dataset-specific code and logging removed for clarity. + +Source: bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py + +Sequence: +1. Initialize distributed process group +2. Create device mesh +3. Create model (optionally on meta device) +4. Apply FSDP sharding to individual layers, then full model +5. Materialize meta-device parameters +6. Create optimizer (MUST be after FSDP wrapping) +7. Training loop with grad clipping and checkpointing +8. Clean up +""" + +import logging +import os +from contextlib import nullcontext +from dataclasses import dataclass, field + +import torch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.optim import AdamW + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Step 1: Distributed configuration +# --------------------------------------------------------------------------- +@dataclass(frozen=True) +class DistributedConfig: + """Reads RANK, LOCAL_RANK, WORLD_SIZE from env vars set by torchrun.""" + + rank: int = field(default_factory=lambda: int(os.environ.setdefault("RANK", "0"))) + local_rank: int = field(default_factory=lambda: int(os.environ.setdefault("LOCAL_RANK", "0"))) + world_size: int = field(default_factory=lambda: int(os.environ.setdefault("WORLD_SIZE", "1"))) + + def is_main_process(self) -> bool: + """Return True if this is the main (rank 0) process.""" + return self.rank == 0 + + +def train(config, model_cls, get_layer_fn, train_dataloader, num_train_steps, use_meta_device=True): + """Main training function demonstrating the FSDP2 pattern. + + Args: + config: Model config (e.g., NVEsmConfig). + model_cls: Model class (e.g., NVEsmForMaskedLM). 
+ get_layer_fn: Function that extracts transformer layers from the model. + E.g., lambda m: m.esm.encoder.layers (encoder) + or lambda m: m.model.layers (decoder) + train_dataloader: DataLoader yielding batches. + num_train_steps: Total training steps. + use_meta_device: Whether to use meta device init (recommended for large models). + """ + # --- Step 1: Initialize distributed --- + dist_config = DistributedConfig() + device = torch.device(f"cuda:{dist_config.local_rank}") + torch.distributed.init_process_group(backend="nccl", device_id=device) + torch.cuda.set_device(dist_config.local_rank) + + # --- Step 2: Create device mesh --- + # 1D mesh = pure data parallelism. For FSDP+CP, use 2D: (dp_size, cp_size). + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(dist_config.world_size,), + mesh_dim_names=("dp",), + ) + + # --- Step 3: Create model (optionally on meta device) --- + # Meta device creates parameter metadata without allocating GPU memory. + # Parameters are materialized after FSDP sharding. + with torch.device("meta") if use_meta_device else nullcontext(): + model = model_cls(config) + + # --- Step 4: Apply FSDP sharding --- + # Mixed precision: FP32 master weights, BF16 compute. + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=False, + ) + + # CRITICAL: Shard individual layers FIRST, then the full model. + # This makes each layer an independent FSDP unit for better + # communication/computation overlap. + transformer_layers = get_layer_fn(model) + for layer in transformer_layers: + fully_shard(layer, mesh=device_mesh["dp"], mp_policy=mp_policy) + fully_shard(model, mesh=device_mesh["dp"], mp_policy=mp_policy) + + # --- Step 5: Materialize meta-device parameters --- + # MUST happen after FSDP sharding but BEFORE optimizer creation. 
+ if use_meta_device: + model.init_empty_weights() # TE layers use reset_parameters() internally + + # --- Step 6: Create optimizer AFTER FSDP wrapping --- + # FSDP replaces original parameters with DTensor shards. + optimizer = AdamW(model.parameters(), lr=4e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.01) + + # --- Step 7: Training loop --- + step = 0 + while step < num_train_steps: + for batch in train_dataloader: + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + + # Forward + outputs = model(**batch) + + # Backward + loss = outputs.loss + loss.backward() + + # Gradient clipping (works across FSDP shards) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Optimizer step + optimizer.step() + optimizer.zero_grad() + + if dist_config.is_main_process() and step % 100 == 0: + logger.info(f"Step {step}, Loss: {loss.item():.4f}") + + step += 1 + if step >= num_train_steps: + break + + # --- Step 8: Clean up --- + torch.distributed.destroy_process_group() diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/SKILL.md b/bionemo-recipes/claude-plugin/skills/te-convert-model/SKILL.md new file mode 100644 index 0000000000..da337384d4 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/SKILL.md @@ -0,0 +1,386 @@ +--- +name: te-convert-model +description: > + Convert a HuggingFace PreTrainedModel to use NVIDIA TransformerEngine layers. + Triggers when user asks to TE-ify, add TransformerEngine, convert for FP8, + or optimize transformer layers with TE. +allowed-tools: Read, Grep, Glob, Write, Edit, Bash, Agent +argument-hint: '[model-path or HF model ID]' +--- + +# Convert a HuggingFace Model to TransformerEngine + +You are converting a HuggingFace `PreTrainedModel` to use NVIDIA TransformerEngine (TE) layers. This enables FP8/FP4 quantization, fused attention kernels, and optimized distributed training. 
+ +## Reference Files + +Before starting, read the reference files in this skill's `reference/` directory: + +- `esm2_convert.py` — Encoder (BERT-like) conversion pattern +- `llama3_convert.py` — Decoder (causal LM) conversion pattern +- `state.py` — State dict transformation framework +- `esm2_modeling_te.py` — Encoder TE model implementation +- `llama3_modeling_te.py` — Decoder TE model implementation + +## Step-by-Step Workflow + +### Step 1: Analyze the Source Model + +Read the model files to identify: + +- **Architecture type**: encoder (BERT, ESM, RoBERTa) vs decoder (GPT, Llama, Mistral) vs encoder-decoder +- **Attention pattern**: MHA (all heads same), GQA (grouped query), MQA (single KV head) +- **Layer structure**: Find `nn.TransformerEncoderLayer`, `nn.MultiheadAttention`, or custom attention +- **FFN pattern**: Standard (dense→activation→dense) vs SwiGLU (gate_proj, up_proj, down_proj) +- **Normalization**: LayerNorm vs RMSNorm +- **Position embeddings**: Absolute, rotary (RoPE), ALiBi, etc. 
+ +### Step 2: Create the NV Config Class + +Extend the source model's config class with TE-specific fields: + +```python +from transformers import SomeConfig # The original config class + + +class NVSomeConfig(SomeConfig): + model_type: str = "nv_some" # New model type for HF registry + + def __init__( + self, + # TE attention format: "bshd" (padded) or "thd" (packed sequences) + attn_input_format: str = "bshd", + # Fuse Q/K/V into single parameter for optimized kernels + fuse_qkv_params: bool = True, + # Padded vocab size for FP8 (must be divisible by 16) + padded_vocab_size: int | None = None, + # Per-layer quantization: ["fp8", "fp4", None] per layer + layer_precision: list[str | None] | None = None, + # Initialize directly in quantized format + use_quantized_model_init: bool = False, + # For decoder models: causal attention mask + # self_attn_mask_type: str = "padding_causal", + **kwargs, + ): + super().__init__(**kwargs) + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.padded_vocab_size = padded_vocab_size or getattr(self, "vocab_size", None) + self.layer_precision = layer_precision + self.use_quantized_model_init = use_quantized_model_init + + # Validate layer_precision + if layer_precision is not None: + if len(layer_precision) != self.num_hidden_layers: + raise ValueError( + f"layer_precision must have length {self.num_hidden_layers}" + ) +``` + +### Step 3: Build the TE Model Class + +Replace standard attention with `transformer_engine.pytorch.TransformerLayer`: + +```python +import transformer_engine.pytorch as te +import transformer_engine.common.recipe +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding + + +class NVSomeModel(PreTrainedModel): + config_class = NVSomeConfig + + def __init__(self, config, fp8_recipe=None, fp4_recipe=None): + super().__init__(config) + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + # Embeddings (standard PyTorch, NOT TE) + 
self.embed_tokens = nn.Embedding(config.padded_vocab_size, config.hidden_size) + + # Build TE transformer layers + layers = [] + for i in range(config.num_hidden_layers): + with self.get_autocast_context(i, init=True): + layers.append( + te.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + # For GQA models: + num_gqa_groups=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + # Encoder: "LayerNorm", Decoder: "RMSNorm" + normalization="LayerNorm", # or "RMSNorm" + # Encoder: "gelu", Decoder: "swiglu" + activation="gelu", # or "swiglu" + # Encoder: "encoder", Decoder: omit or use default + layer_type="encoder", # omit for decoder + attn_input_format=config.attn_input_format, + self_attn_mask_type=getattr( + config, "attn_mask_type", "padding" + ), + fuse_qkv_params=config.fuse_qkv_params, + qkv_weight_interleaved=True, + layer_number=i + 1, # 1-indexed! + bias=True, # False for Llama-style + params_dtype=config.dtype, + device=( + "meta" + if torch.get_default_device() == torch.device("meta") + else "cuda" + ), + ) + ) + self.layers = nn.ModuleList(layers) + + # Final layer norm + self.norm = te.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # Or for RMSNorm: te.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Rotary embeddings + self.rotary_emb = RotaryPositionEmbedding( + config.hidden_size // config.num_attention_heads + ) +``` + +**Key patterns:** + +- `layer_number` is 1-indexed (not 0-indexed) +- Encoder models use `layer_type="encoder"` and `normalization="LayerNorm"` +- Decoder models use `normalization="RMSNorm"` and `activation="swiglu"` +- GQA models set `num_gqa_groups=config.num_key_value_heads` +- For encoder models, `num_gqa_groups=config.num_attention_heads` (MHA = all heads) + +**Override `state_dict()` to filter TE internal state:** + +```python +def state_dict(self, *args, **kwargs): + sd = 
super().state_dict(*args, **kwargs) + return {k: v for k, v in sd.items() if not k.endswith("_extra_state")} +``` + +**Implement `init_empty_weights()` for meta device support:** + +```python +def init_empty_weights(self): + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + self.embed_tokens.to_empty(device="cuda") + self.embed_tokens.apply(self._init_weights) + self.tie_weights() +``` + +**Implement `get_autocast_context()` for FP8/FP4:** + +```python +from contextlib import nullcontext + + +def get_autocast_context(self, layer_number, init=False, outer=False): + if self.config.layer_precision is None: + return nullcontext() + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + return te.autocast(enabled=True, recipe=self._fp8_recipe) + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return te.quantized_model_init(recipe=recipe) + return nullcontext() + if precision in ("fp8", "fp4"): + return te.autocast(enabled=True, recipe=recipe) + return te.autocast(enabled=False) +``` + +**LM Head — keep in higher precision:** + +```python +class NVSomeForMaskedLM(PreTrainedModel): + def forward(self, ...): + hidden = self.model(...) 
+ # Disable FP8 for the LM head to maintain precision + with te.autocast(enabled=False): + logits = self.lm_head(hidden) + # Truncate padded vocab logits + if self.config.padded_vocab_size != self.config.vocab_size: + logits = logits[..., :self.config.vocab_size] +``` + +### Step 4: Write the State Dict Mapping + +Create a mapping dict that renames HF state dict keys to TE keys: + +**For encoder models (BERT-like):** + +```python +mapping = { + # Attention output projection + "encoder.layer.*.attention.output.dense.weight": "encoder.layers.*.self_attention.proj.weight", + "encoder.layer.*.attention.output.dense.bias": "encoder.layers.*.self_attention.proj.bias", + # Attention LayerNorm → TE's layernorm_qkv (fused LN + QKV) + "encoder.layer.*.attention.LayerNorm.weight": "encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "encoder.layer.*.attention.LayerNorm.bias": "encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + # FFN layers → TE's layernorm_mlp (fused LN + MLP) + "encoder.layer.*.intermediate.dense.weight": "encoder.layers.*.layernorm_mlp.fc1_weight", + "encoder.layer.*.intermediate.dense.bias": "encoder.layers.*.layernorm_mlp.fc1_bias", + "encoder.layer.*.output.dense.weight": "encoder.layers.*.layernorm_mlp.fc2_weight", + "encoder.layer.*.output.dense.bias": "encoder.layers.*.layernorm_mlp.fc2_bias", + # FFN LayerNorm + "encoder.layer.*.LayerNorm.weight": "encoder.layers.*.layernorm_mlp.layer_norm_weight", + "encoder.layer.*.LayerNorm.bias": "encoder.layers.*.layernorm_mlp.layer_norm_bias", +} +``` + +**For decoder models (Llama-like):** + +```python +mapping = { + "model.embed_tokens.weight": "model.embed_tokens.weight", + "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight", + "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight", + 
"model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight", + "model.norm.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", +} +``` + +**TE naming conventions:** + +- `self_attention.layernorm_qkv` — Fused LayerNorm + QKV projection + - `.weight` = QKV weights, `.bias` = QKV bias + - `.layer_norm_weight` = LayerNorm weight, `.layer_norm_bias` = LayerNorm bias +- `self_attention.proj` — Output projection +- `layernorm_mlp` — Fused LayerNorm + MLP + - `.fc1_weight` / `.fc1_bias` = First FFN layer (or gate+up for SwiGLU) + - `.fc2_weight` / `.fc2_bias` = Second FFN layer + - `.layer_norm_weight` / `.layer_norm_bias` = LayerNorm + +### Step 5: Write Bidirectional Conversion + +**HF → TE conversion:** + +```python +import state # The transform system from reference/state.py + + +def convert_hf_to_te(model_hf, **config_kwargs): + te_config = NVSomeConfig(**model_hf.config.to_dict(), **config_kwargs) + with torch.device("meta"): + model_te = NVSomeForMaskedLM(te_config) + + return state.apply_transforms( + model_hf, + model_te, + mapping, + transforms=[ + # For encoder: pack Q/K/V into fused QKV + _pack_qkv_weight, + _pack_qkv_bias, + # For decoder: use TransformFns + # state.state_transform( + # source_key=("*.q_proj.weight", "*.k_proj.weight", "*.v_proj.weight"), + # target_key="*.layernorm_qkv.weight", + # fn=state.TransformFns.merge_qkv, + # ), + # state.state_transform( + # source_key=("*.gate_proj.weight", "*.up_proj.weight"), + # target_key="*.layernorm_mlp.fc1_weight", + # fn=state.TransformFns.merge_fc1, + # ), + ], + ) +``` + +**TE → HF conversion:** + +```python +def convert_te_to_hf(model_te, **config_kwargs): + import inspect + + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(OriginalConfig.__init__).parameters) + filtered = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = OriginalConfig(**filtered, **config_kwargs) + + with torch.device("meta"): + model_hf = 
OriginalModel(hf_config) + + reverse_mapping = {v: k for k, v in mapping.items()} + return state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + transforms=[_unpack_qkv_weight, _unpack_qkv_bias], + state_dict_ignored_entries=[...], # tied weights, etc. + ) +``` + +**QKV packing for MHA (encoder) — interleaved format:** + +```python +@state.state_transform( + source_key=("*.query.weight", "*.key.weight", "*.value.weight"), + target_key="*.layernorm_qkv.weight", +) +def _pack_qkv_weight(ctx, query, key, value): + concat = torch.cat((query, key, value), dim=0) + num_heads = ctx.target.config.num_attention_heads + concat = concat.view(3, num_heads, -1, query.size(-1)) + concat = concat.transpose(0, 1).contiguous() + return concat.view(-1, query.size(-1)) +``` + +**QKV merging for GQA (decoder) — use TransformFns:** + +```python +state.state_transform( + source_key=("*.q_proj.weight", "*.k_proj.weight", "*.v_proj.weight"), + target_key="*.layernorm_qkv.weight", + fn=state.TransformFns.merge_qkv, +) +``` + +### Step 6: Write Golden Value Test + +```python +def test_golden_values(): + model_hf = OriginalModel.from_pretrained("model-id", dtype=torch.bfloat16).cuda() + model_te = convert_hf_to_te(model_hf) + model_te.to("cuda") + + input_data = prepare_test_input() + + with torch.no_grad(): + hf_out = model_hf(**input_data) + te_out = model_te(**input_data) + + torch.testing.assert_close(te_out.loss, hf_out.loss, atol=1e-2, rtol=1e-3) + torch.testing.assert_close(te_out.logits, hf_out.logits, atol=2.0, rtol=1e-4) +``` + +### Step 7: Add AUTO_MAP for HuggingFace Integration + +Define in the model file for `AutoModel.from_pretrained()` with `trust_remote_code=True`: + +```python +AUTO_MAP = { + "AutoConfig": "model_file_name.NVSomeConfig", + "AutoModel": "model_file_name.NVSomeModel", + "AutoModelForMaskedLM": "model_file_name.NVSomeForMaskedLM", # or ForCausalLM +} +``` + +## Important Notes + +- **Copy `state.py`** from the reference directory into the user's 
project. It is a standalone utility. +- **Embedding layer stays in PyTorch** — only transformer layers use TE. +- **FP32 rotary embeddings** — always compute RoPE outside `torch.autocast` for stability. +- **Tied weights** — call `self.tie_weights()` after conversion and after `init_empty_weights()`. +- **`_extra_state`** — TE adds internal state that must be filtered from `state_dict()`. +- **Vocab padding** — for FP8, pad vocab to multiple of 16; fill padding with zeros (embeddings) or min float (bias). diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_convert.py b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_convert.py new file mode 100644 index 0000000000..be53b12f79 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_convert.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: ESM2 HF<->TE Conversion (Encoder/BERT-like pattern). + +This file demonstrates how to convert a HuggingFace encoder model to TransformerEngine. 
+Key patterns: +- Mapping dict with wildcard layer indices +- QKV packing: separate Q/K/V -> fused interleaved QKV +- Embedding padding for FP8 compatibility +- Bidirectional conversion (HF->TE and TE->HF) +""" + +import inspect + +# NOTE: These imports are relative - adjust for your project structure +import state +import torch +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM +from torch import nn + + +# NOTE: The mapping dict renames HF state dict keys to TE keys. +# Use "*" as wildcard for layer indices (e.g., layer.0, layer.1, ...). +# TE uses different naming: layernorm_qkv (fused LN+QKV), layernorm_mlp (fused LN+MLP) +mapping = { + # Attention output projection + "esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias", + # Attention LayerNorm -> fused into TE's layernorm_qkv + "esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + # FFN intermediate -> TE's fc1 (first linear in MLP) + "esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight", + "esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias", + # FFN output -> TE's fc2 (second linear in MLP) + "esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight", + "esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias", + # FFN LayerNorm -> fused into TE's layernorm_mlp + "esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias", + # Post-encoder LayerNorm (not fused) + 
"esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight", + "esm.encoder.emb_layer_norm_after.bias": "esm.encoder.emb_layer_norm_after.bias", + # LM head + "lm_head.dense.weight": "lm_head.dense.weight", + "lm_head.dense.bias": "lm_head.dense.bias", + "lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight", + "lm_head.layer_norm.bias": "lm_head.decoder.layer_norm_bias", +} + +reverse_mapping = {v: k for k, v in mapping.items()} + + +def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module: + """Convert HuggingFace ESM2 to TransformerEngine format. + + NOTE: The pattern is: + 1. Create NV config from HF config (pass through all existing fields + add TE fields) + 2. Create empty TE model on meta device (avoids GPU memory for large models) + 3. Apply transforms to copy and reshape weights + """ + from accelerate import init_empty_weights + + te_config = NVEsmConfig(**model_hf.config.to_dict(), **config_kwargs) + with init_empty_weights(): + model_te = NVEsmForMaskedLM(te_config) + + output_model = state.apply_transforms( + model_hf, + model_te, + mapping, + [ + _pack_qkv_weight, # Merge Q/K/V weights into fused QKV + _pack_qkv_bias, # Merge Q/K/V biases into fused QKV + _pad_embeddings, # Pad embedding matrix for FP8 + _pad_decoder_weights, # Pad LM head weights + _pad_bias, # Pad LM head bias (with -inf for padding positions) + ], + ) + return output_model + + +def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module: + """Convert TE format back to HuggingFace format. + + NOTE: Filter out TE-specific config keys that aren't valid for the original config class. 
+ """ + from accelerate import init_empty_weights + from transformers import EsmConfig, EsmForMaskedLM + + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(EsmConfig.__init__).parameters) + filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = EsmConfig(**filtered_config, **config_kwargs) + + with init_empty_weights(): + model_hf = EsmForMaskedLM(hf_config) + + output_model = state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [_unpack_qkv_weight, _unpack_qkv_bias, _unpad_embeddings, _unpad_decoder_weights, _unpad_bias], + state_dict_ignored_entries=["lm_head.decoder.weight"], # Tied weight + ) + output_model.post_init() + return output_model + + +# NOTE: QKV packing for MHA (Multi-Head Attention) uses interleaved format. +# For each head, Q/K/V weights are interleaved: [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v, ...] +# This is required when qkv_weight_interleaved=True in TE. +@state.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), + target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", +) +def _pack_qkv_weight(ctx, query, key, value): + """Pack separate Q, K, V weights into interleaved QKV format.""" + concat_weights = torch.cat((query, key, value), dim=0) + input_shape = concat_weights.size() + num_heads = ctx.target.config.num_attention_heads + concat_weights = concat_weights.view(3, num_heads, -1, query.size()[-1]) + concat_weights = concat_weights.transpose(0, 1).contiguous() + return concat_weights.view(*input_shape) + + +@state.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), + target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", +) +def _pack_qkv_bias(ctx, query, 
key, value): + """Pack separate Q, K, V biases into interleaved QKV format.""" + concat_biases = torch.cat((query, key, value), dim=0) + input_shape = concat_biases.size() + num_heads = ctx.target.config.num_attention_heads + concat_biases = concat_biases.view(3, num_heads, -1) + concat_biases = concat_biases.transpose(0, 1).contiguous() + return concat_biases.view(*input_shape) + + +# NOTE: Reverse transforms for TE->HF conversion +@state.state_transform( + source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), +) +def _unpack_qkv_weight(ctx, qkv_weight): + """Unpack fused QKV weights back to separate Q, K, V.""" + num_heads = ctx.source.config.num_attention_heads + total_rows, input_dim = qkv_weight.size() + head_dim = total_rows // (3 * num_heads) + qkv_weight = qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous() + return ( + qkv_weight[0].reshape(-1, input_dim), + qkv_weight[1].reshape(-1, input_dim), + qkv_weight[2].reshape(-1, input_dim), + ) + + +@state.state_transform( + source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + target_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), +) +def _unpack_qkv_bias(ctx, qkv_bias): + """Unpack fused QKV biases back to separate Q, K, V.""" + num_heads = ctx.source.config.num_attention_heads + total_size = qkv_bias.size(0) + head_dim = total_size // (3 * num_heads) + qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous() + return qkv_bias[0].reshape(-1), qkv_bias[1].reshape(-1), qkv_bias[2].reshape(-1) + + +# NOTE: Embedding padding -- pad vocab to padded_vocab_size for FP8 compatibility. +# For embeddings: pad with zeros. 
For bias: pad with -inf (so softmax ignores padding). +def _pad_weights(ctx, source_embed): + target_dim = ctx.target.config.padded_vocab_size + num_padding = target_dim - source_embed.size(0) + padding = torch.zeros(num_padding, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device) + return torch.cat((source_embed, padding), dim=0) + + +def _unpad_weights(ctx, padded_embed): + return padded_embed[: ctx.target.config.vocab_size] + + +_pad_embeddings = state.state_transform( + "esm.embeddings.word_embeddings.weight", "esm.embeddings.word_embeddings.weight" +)(_pad_weights) +_pad_decoder_weights = state.state_transform("lm_head.decoder.weight", "lm_head.decoder.weight")(_pad_weights) +_unpad_embeddings = state.state_transform( + "esm.embeddings.word_embeddings.weight", "esm.embeddings.word_embeddings.weight" +)(_unpad_weights) +_unpad_decoder_weights = state.state_transform("lm_head.decoder.weight", "lm_head.decoder.weight")(_unpad_weights) + + +@state.state_transform(source_key="lm_head.bias", target_key="lm_head.decoder.bias") +def _pad_bias(ctx, source_bias): + """Pad bias with -inf so padded positions produce ~0 probability after softmax.""" + target_dim = ctx.target.config.padded_vocab_size + output_bias = torch.finfo(source_bias.dtype).min * torch.ones( + target_dim, dtype=source_bias.dtype, device=source_bias.device + ) + output_bias[: source_bias.size(0)] = source_bias + return output_bias + + +@state.state_transform(source_key="lm_head.decoder.bias", target_key="lm_head.bias") +def _unpad_bias(ctx, padded_bias): + return padded_bias[: ctx.target.config.vocab_size] diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_modeling_te.py b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_modeling_te.py new file mode 100644 index 0000000000..0cca8e88cf --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/esm2_modeling_te.py @@ -0,0 +1,466 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: ESM2 TransformerEngine Model (Encoder/BERT-like pattern). + +This file shows how to wrap TransformerEngine layers in a HuggingFace-compatible +encoder model. Key patterns: +- NVEsmConfig extends EsmConfig with TE-specific fields +- NVEsmEncoder uses te.TransformerLayer with fused LN+QKV and LN+MLP +- NVEsmPreTrainedModel handles meta-device init, _init_weights, state_dict filtering +- LM head runs in higher precision (autocast disabled) for numerical stability +- AUTO_MAP dict enables AutoModel.from_pretrained() compatibility +""" + +from contextlib import nullcontext +from typing import ClassVar, ContextManager, Literal, Optional, Unpack + +import torch +import transformer_engine.common.recipe +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils.generic import TransformersKwargs + + +# NOTE: AUTO_MAP tells HuggingFace's Auto** classes which module contains our model. 
+# The prefix (e.g., "esm_nv.") must match the filename used in the exported checkpoint. +AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """Extended ESM config with TransformerEngine-specific fields.""" + + model_type: str = "nv_esm" + + def __init__( + self, + # NOTE: These are the key TE-specific config fields to add for any encoder model + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + layer_precision: list[str | None] | None = None, + use_quantized_model_init: bool = False, + **kwargs, + ): + """Initialize NVEsmConfig. + + Args: + qkv_weight_interleaved: If True, QKV weight is interleaved per-head [q0,k0,v0,q1,...]. + If False, it's concatenated [Q,K,V]. Must match the conversion code. + encoder_activation: Activation function ("gelu", "swiglu", etc.). + attn_input_format: "bshd" for padded batches, "thd" for packed sequences. + fuse_qkv_params: Expose single fused QKV parameter (enables QKV fusion). + micro_batch_size: Micro batch size for JIT warmup. + max_seq_length: Max sequence length for JIT warmup. + padded_vocab_size: Pad embedding to this size for FP8 alignment. + attn_mask_type: Attention mask type ("padding" for encoder). + layer_precision: Per-layer quantization: "fp8", "fp4", or None per layer. + use_quantized_model_init: Use quantized_model_init context during construction. + **kwargs: Additional config options passed to EsmConfig. 
+ """ + super().__init__(**kwargs) + self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + self.layer_precision = layer_precision + self.use_quantized_model_init = use_quantized_model_init + + # NOTE: padded_vocab_size must be >= vocab_size, used for FP8 alignment + self.padded_vocab_size = padded_vocab_size or self.vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size + + +class NVEsmEncoder(nn.Module): + """TransformerEngine-optimized ESM encoder stack.""" + + def __init__( + self, + config: NVEsmConfig, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ): + """Initialize NVEsmEncoder. + + Args: + config: The model configuration. + fp8_recipe: FP8 recipe for the encoder. + fp4_recipe: FP4 recipe for the encoder. 
+ """ + super().__init__() + self.config = config + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + # NOTE: Default layer_precision from recipe if not explicitly set + if self.config.layer_precision is None and fp8_recipe is not None: + self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + # NOTE: Each layer is created inside get_autocast_context for proper FP8/FP4 init + layers = [] + for i in range(config.num_hidden_layers): + with self.get_autocast_context(i, init=True): + layers.append( + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + # NOTE: For MHA (not GQA), num_gqa_groups == num_attention_heads + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + ) + + self.layers = nn.ModuleList(layers) + + # NOTE: Post-encoder LayerNorm (not fused into any TransformerLayer) + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": 
+ self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def forward(self, hidden_states, attention_mask=None, **kwargs: Unpack[TransformersKwargs]): + """Forward pass through encoder stack.""" + all_hidden_states = () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + hidden_states = hidden_states.squeeze(0) + + # NOTE: Rotary embeddings must be computed in higher precision + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + + with self.get_autocast_context(None, outer=True): + for layer_idx, layer_module in enumerate(self.layers): + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + with self.get_autocast_context(layer_idx): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states or None) + + def get_autocast_context( + self, layer_number: int | None, init: bool = False, outer: bool = False + ) -> ContextManager: + """Return the appropriate TE autocast context for a given layer. 
+ + NOTE: This handles three cases: + - init=True: Return quantized_model_init context for layer construction + - outer=True: Return global te.autocast wrapping the entire encoder forward + - Otherwise: Return per-layer te.autocast based on layer_precision + """ + if self.config.layer_precision is None: + return nullcontext() + + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe) + + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return transformer_engine.pytorch.quantized_model_init(recipe=recipe) + return nullcontext() + + if precision == "fp8": + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + if precision == "fp4": + if recipe is None: + raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.") + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + return transformer_engine.pytorch.autocast(enabled=False) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """Base class handling TE-specific weight init, meta device support, and state_dict filtering.""" + + config_class = NVEsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = False + _no_split_modules = ("TransformerLayer", "EsmEmbeddings") + + def init_empty_weights(self): + """Move model from meta device to CUDA and initialize weights. + + NOTE: For TE layers, reset_parameters() handles both device placement + and weight initialization. For non-TE layers (embeddings), use standard init. 
+ """ + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + self.esm.embeddings.word_embeddings.to_empty(device="cuda") + self.esm.embeddings.apply(self._init_weights) + self.tie_weights() + + def _init_weights(self, module): + """Initialize weights, skipping TE modules which handle their own init. + + NOTE: Must skip TE modules because the default HF _init_weights assumes any class + with 'LayerNorm' in the name should have weights=1.0, which breaks + LayerNormLinear/LayerNormMLP that use 'weight' for the linear part. + """ + if module.__module__.startswith("transformer_engine.pytorch"): + if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): + module.reset_parameters() + return + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Filter out TE's _extra_state keys for HF compatibility. + + NOTE: TE layers add _extra_state attributes that break HF model loading. + """ + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """ESM2 masked language model with TransformerEngine encoder.""" + + # NOTE: Tied weights - decoder weight points to embedding weight + _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + + def __init__( + self, + config: NVEsmConfig, + fp8_recipe=None, + fp4_recipe=None, + ): + """Initialize NVEsmForMaskedLM. + + Args: + config: The model configuration. + fp8_recipe: FP8 recipe for the encoder. + fp4_recipe: FP4 recipe for the encoder. 
+ """ + super().__init__(config) + self.esm = NVEsmModel(config, add_pooling_layer=False, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + self.lm_head = NVEsmLMHead(config) + self.post_init() + + def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs): + """Forward pass for masked language modeling.""" + outputs = self.esm(input_ids, attention_mask=attention_mask, **kwargs) + sequence_output = outputs[0] + + # NOTE: LM head runs with autocast DISABLED for numerical stability + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) + + # NOTE: Truncate logits back to original vocab_size (remove FP8 padding) + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + loss = None + if labels is not None: + loss = CrossEntropyLoss()( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput(loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states) + + +class NVEsmLMHead(nn.Module): + """LM head: Linear -> GELU -> LayerNormLinear (to vocab). + + NOTE: Uses quantized_model_init(enabled=False) to ensure LM head stays in + higher precision even when the rest of the model uses FP8/FP4. + """ + + def __init__(self, config: NVEsmConfig): + """Initialize NVEsmLMHead. + + Args: + config: The model configuration. 
+ """ + super().__init__() + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size or config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + + def forward(self, features, **kwargs): + """Forward pass through dense + gelu + decoder projection.""" + # NOTE: Keep LM head in higher precision to avoid numerical instability + with transformer_engine.pytorch.autocast(enabled=False): + x = self.dense(features) + x = torch.nn.functional.gelu(x) + x = self.decoder(x) + return x + + +class NVEsmModel(NVEsmPreTrainedModel): + """ESM encoder-only model with TE-optimized layers.""" + + def __init__(self, config, add_pooling_layer=True, fp8_recipe=None, fp4_recipe=None): + """Initialize NVEsmModel. + + Args: + config: The model configuration. + add_pooling_layer: Whether to add a pooling layer. + fp8_recipe: FP8 recipe for the encoder. + fp4_recipe: FP4 recipe for the encoder. 
+ """ + super().__init__(config) + self.config = config + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config, fp8_recipe, fp4_recipe) + self.pooler = EsmPooler(config) if add_pooling_layer else None + self.post_init() + + def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs): + """Forward pass through the ESM model.""" + if input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + # NOTE: TE expects boolean mask where True=masked, opposite of HF convention + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmEmbeddings(nn.Module): + """Embedding layer with padded vocab size for FP8 compatibility.""" + + def __init__(self, config): + """Initialize NVEsmEmbeddings. + + Args: + config: The model configuration. 
+ """ + super().__init__() + # NOTE: Use padded_vocab_size for the embedding table + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, config.hidden_size, padding_idx=config.pad_token_id, dtype=config.dtype + ) + self.layer_norm = ( + transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.emb_layer_norm_before + else None + ) + if config.position_embedding_type != "rotary": + raise ValueError("TE ESM-2 only supports rotary position embeddings") + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs): + """Compute token embeddings.""" + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + embeddings = inputs_embeds + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + return embeddings diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_convert.py b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_convert.py new file mode 100644 index 0000000000..a58fa0e3ee --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_convert.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: Llama3 HF<->TE Conversion (Decoder/Causal LM pattern). + +This file demonstrates decoder model conversion with: +- GQA (Group Query Attention) -> fused QKV with TransformFns.merge_qkv +- SwiGLU FFN (gate+up projections) -> fused fc1 with TransformFns.merge_fc1 +- Tied word embeddings handling +- Rotary embedding preservation +""" + +import inspect + +# NOTE: For decoder models, the mapping is simpler because TE's TransformerLayer +# handles more of the fusion internally. The key difference from encoder models: +# - Q/K/V are fused using TransformFns.merge_qkv (handles GQA properly) +# - Gate+Up projections are fused using TransformFns.merge_fc1 +import state +import torch + + +mapping = { + "model.embed_tokens.weight": "model.embed_tokens.weight", + # NOTE: input_layernorm -> layernorm_qkv.layer_norm_weight (fused into attention) + "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + # Output projection stays as-is + "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight", + # NOTE: post_attention_layernorm -> layernorm_mlp.layer_norm_weight (fused into MLP) + "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight", + # down_proj -> fc2 (second linear in MLP) + "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight", + "model.norm.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", +} + +reverse_mapping = {v: k for k, v in mapping.items()} + + +def 
convert_llama_hf_to_te(model_hf, **config_kwargs): + """Convert HuggingFace Llama to TransformerEngine. + + NOTE: Key differences from encoder conversion: + - Uses TransformFns.merge_qkv for GQA-aware Q/K/V fusion + - Uses TransformFns.merge_fc1 for gate/up projection fusion + - Handles tied word embeddings (skip lm_head.weight if tied) + - Copies rotary_emb.inv_freq separately + """ + from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM + + te_config = NVLlamaConfig(**model_hf.config.to_dict(), **config_kwargs) + with torch.device("meta"): + model_te = NVLlamaForCausalLM(te_config) + + # NOTE: Handle tied embeddings - if tied, skip lm_head.weight in target + state_dict_ignored_entries = ["lm_head.weight"] if model_hf.config.tie_word_embeddings else [] + + output_model = state.apply_transforms( + model_hf, + model_te, + mapping, + [ + # NOTE: TransformFns.merge_qkv handles GQA automatically. + # It interleaves Q heads with their corresponding K/V heads. + state.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="model.layers.*.self_attention.layernorm_qkv.weight", + fn=state.TransformFns.merge_qkv, + ), + # NOTE: For SwiGLU, gate and up projections are concatenated into fc1. 
+ state.state_transform( + source_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + target_key="model.layers.*.layernorm_mlp.fc1_weight", + fn=state.TransformFns.merge_fc1, + ), + ], + state_dict_ignored_entries=state_dict_ignored_entries, + ) + + # NOTE: Rotary embeddings are not part of state_dict, copy manually + output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone() + return output_model + + +def convert_llama_te_to_hf(model_te, **config_kwargs): + """Convert TE Llama back to HuggingFace format.""" + from transformers import LlamaConfig, LlamaForCausalLM + + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(LlamaConfig.__init__).parameters) + filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = LlamaConfig(**filtered_config, **config_kwargs) + + with torch.device("meta"): + model_hf = LlamaForCausalLM(hf_config) + + output_model = state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [ + # NOTE: split_qkv reverses merge_qkv, handling GQA head grouping + state.state_transform( + source_key="model.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + fn=state.TransformFns.split_qkv, + ), + # NOTE: split_fc1 reverses merge_fc1 via torch.chunk(2) + state.state_transform( + source_key="model.layers.*.layernorm_mlp.fc1_weight", + target_key=( + "model.layers.*.mlp.gate_proj.weight", + "model.layers.*.mlp.up_proj.weight", + ), + fn=state.TransformFns.split_fc1, + ), + ], + state_dict_ignored_entries=model_hf._tied_weights_keys, + ) + + output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone() + output_model.tie_weights() + return output_model diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_modeling_te.py 
b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_modeling_te.py new file mode 100644 index 0000000000..9bf50c05df --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/llama3_modeling_te.py @@ -0,0 +1,425 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: Llama3 TransformerEngine Model (Decoder/Causal LM pattern). + +This file shows how to wrap TransformerEngine layers in a HuggingFace-compatible +decoder model. 
Key patterns: +- NVLlamaConfig extends LlamaConfig with minimal TE fields +- NVLlamaModel uses te.TransformerLayer with RMSNorm, SwiGLU, GQA +- Automatic BSHD->THD conversion for efficient packed-sequence processing +- NVLlamaForCausalLM with GenerationMixin for text generation support +- LM head in higher precision (autocast disabled) +- AUTO_MAP dict for AutoModelForCausalLM.from_pretrained() compatibility +""" + +from contextlib import nullcontext +from typing import ClassVar, ContextManager, Unpack + +import torch +import torch.nn as nn +import transformer_engine.common.recipe +import transformer_engine.pytorch +import transformers +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers import LlamaConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding +from transformers.utils.generic import TransformersKwargs + + +# NOTE: AUTO_MAP keys must match the model classes you want Auto** to resolve to. +# The value prefix must match the filename used in the exported checkpoint. +AUTO_MAP = { + "AutoConfig": "modeling_llama_te.NVLlamaConfig", + "AutoModel": "modeling_llama_te.NVLlamaModel", + "AutoModelForCausalLM": "modeling_llama_te.NVLlamaForCausalLM", +} + + +class NVLlamaConfig(LlamaConfig): + """Extended Llama config with TE-specific fields. + + NOTE: Decoder models need fewer TE-specific fields than encoder models because + TE's TransformerLayer handles most config via constructor args directly. + """ + + # NOTE: "thd" is preferred for decoders - enables packed sequence processing + attn_input_format: str = "thd" + self_attn_mask_type: str = "padding_causal" + + def __init__( + self, + layer_precision: list[str | None] | None = None, + use_quantized_model_init: bool = False, + **kwargs, + ): + """Initialize NVLlamaConfig. + + Args: + layer_precision: Per-layer quantization precision list. 
+ use_quantized_model_init: Use quantized_model_init for layer init. + **kwargs: Additional config options passed to LlamaConfig. + """ + super().__init__(**kwargs) + self.layer_precision = layer_precision + self.use_quantized_model_init = use_quantized_model_init + + +class NVLlamaPreTrainedModel(PreTrainedModel): + """Base class for NVLlama models.""" + + config_class = NVLlamaConfig + base_model_prefix = "model" + _no_split_modules = ("TransformerLayer",) + _skip_keys_device_placement = ("past_key_values",) + + def init_empty_weights(self): + """Move model from meta device to CUDA and initialize weights. + + NOTE: TE layers use reset_parameters(). Non-TE layers (embed_tokens) + use standard HF init. Rotary embeddings are recomputed fresh. + """ + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + self.model.embed_tokens.to_empty(device="cuda") + self.model.embed_tokens.apply(self._init_weights) + self.model.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=self.model.config).inv_freq.to("cuda") + self.tie_weights() + + def _init_weights(self, module): + """Skip TE modules (they handle their own init via reset_parameters).""" + if module.__module__.startswith("transformer_engine.pytorch"): + return + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Filter out TE's _extra_state keys for HF compatibility.""" + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + + +class NVLlamaModel(NVLlamaPreTrainedModel): + """Llama3 decoder model with TransformerEngine layers.""" + + def __init__( + self, + config: LlamaConfig, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ): + """Initialize NVLlamaModel. + + Args: + config: The model configuration. + fp8_recipe: FP8 recipe for the model. + fp4_recipe: FP4 recipe for the model. 
+ """ + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + # NOTE: Default layer_precision from recipe if not explicitly set + if self.config.layer_precision is None and fp8_recipe is not None: + self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers + + # NOTE: Embedding is standard nn.Embedding (no padding needed for decoder models) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype) + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + # NOTE: Key TransformerLayer differences from encoder: + # - bias=False (Llama uses no bias) + # - normalization="RMSNorm" (not LayerNorm) + # - activation="swiglu" (gated activation) + # - num_gqa_groups=num_key_value_heads (GQA, not MHA) + layers = [] + for layer_idx in range(config.num_hidden_layers): + with self.get_autocast_context(layer_idx, init=True): + layers.append( + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=True, + qkv_weight_interleaved=True, + normalization="RMSNorm", + activation="swiglu", + attn_input_format=config.attn_input_format, + self_attn_mask_type=config.self_attn_mask_type, + num_gqa_groups=config.num_key_value_heads, + layer_number=layer_idx + 1, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + ) + + self.layers = nn.ModuleList(layers) + self.norm = transformer_engine.pytorch.RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.dtype, + device="meta" 
if torch.get_default_device() == torch.device("meta") else "cuda", + ) + + # NOTE: Use TE's RotaryPositionEmbedding but with HF's inv_freq for compatibility + self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq + + self.gradient_checkpointing = False + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + past_key_values=None, + inputs_embeds=None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + """Forward pass with BSHD to THD dynamic conversion.""" + all_hidden_states = [] + output_hidden_states = kwargs.get("output_hidden_states", False) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + # NOTE: Auto-convert BSHD -> THD for efficient packed-sequence processing. + # This enables HF-style generation (which provides BSHD inputs) to work + # transparently with TE's THD-optimized attention. 
+ has_thd_input = [x in kwargs for x in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]] + should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd" + + if should_pack_inputs: + assert attention_mask is not None + batch_size = hidden_states.size(0) + padded_seq_len = input_ids.size(1) + hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask) + kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens + kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + hidden_states = hidden_states.squeeze(0) + + if self.config.attn_input_format == "bshd" and attention_mask is not None and attention_mask.dim() == 2: + attention_mask = ~attention_mask[:, None, None, :].bool() + + # NOTE: Rotary embeddings in higher precision + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings) + + with self.get_autocast_context(None, outer=True): + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + with self.get_autocast_context(layer_idx): + hidden_states = decoder_layer( + hidden_states, + attention_mask=None if self.config.attn_input_format == "thd" else attention_mask, + rotary_pos_emb=te_rope_emb, + inference_params=past_key_values, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = (*all_hidden_states, hidden_states) + + # NOTE: Convert THD back to BSHD for HF-compatible output + if should_pack_inputs: + hidden_states = _pad_input(hidden_states, indices, batch_size, 
padded_seq_len) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states if output_hidden_states else None, + ) + + def get_autocast_context( + self, layer_number: int | None, init: bool = False, outer: bool = False + ) -> ContextManager: + """Return appropriate TE autocast context for a given layer. + + Same pattern as encoder -- see NVEsmEncoder.get_autocast_context for details. + """ + if self.config.layer_precision is None: + return nullcontext() + + if outer: + if "fp8" not in self.config.layer_precision: + return nullcontext() + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp8_recipe) + + precision = self.config.layer_precision[layer_number] + recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision) + + if init and self.config.use_quantized_model_init: + if precision in ("fp8", "fp4"): + return transformer_engine.pytorch.quantized_model_init(recipe=recipe) + return nullcontext() + + if precision == "fp8": + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + if precision == "fp4": + if recipe is None: + raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.") + return transformer_engine.pytorch.autocast(enabled=True, recipe=recipe) + return transformer_engine.pytorch.autocast(enabled=False) + + +class NVLlamaForCausalLM(NVLlamaPreTrainedModel, transformers.GenerationMixin): + """Llama3 causal LM with generation support. + + NOTE: Inherits GenerationMixin for HF generate() compatibility. + """ + + # NOTE: Tied weights - lm_head.weight points to embed_tokens.weight + _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config, fp8_recipe=None, fp4_recipe=None): + """Initialize NVLlamaForCausalLM. + + Args: + config: The model configuration. + fp8_recipe: FP8 recipe for the model. + fp4_recipe: FP4 recipe for the model. 
+ """ + super().__init__(config) + self.model = NVLlamaModel(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe) + self.vocab_size = config.vocab_size + + # NOTE: LM head created with quantized_model_init DISABLED for numerical stability + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.lm_head = transformer_engine.pytorch.Linear( + config.hidden_size, + config.vocab_size, + bias=False, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + self.post_init() + + def forward( + self, + input_ids=None, + attention_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + shift_labels=None, + logits_to_keep=0, + **kwargs, + ) -> CausalLMOutputWithPast: + """Forward pass for causal language modeling.""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + # NOTE: LM head with autocast DISABLED for numerical stability + with transformer_engine.pytorch.autocast(enabled=False): + if hidden_states.ndim == 3: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + else: # THD format: batch and sequence collapsed in first dimension + logits = self.lm_head(hidden_states[slice_indices, :]) + + loss = None + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, shift_labels=shift_labels, vocab_size=self.config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + ) + + +# --- Helper functions for BSHD <-> THD conversion --- + +torch._dynamo.config.capture_scalar_outputs = True + + +@torch.compile +def _pad_input(hidden_states, indices, 
batch, seqlen): + """Convert THD tensor back to BSHD format.""" + dim = hidden_states.shape[1:] + output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) + output[indices] = hidden_states + return output.view(batch, seqlen, *dim) + + +@torch.compile +def _unpad_input(hidden_states, attention_mask, unused_mask=None): + """Convert BSHD tensor to THD format by removing padding. + + Returns: + hidden_states: (total_tokens, hidden_size) + indices: indices of non-padding tokens in flattened input + cu_seqlens: cumulative sequence lengths for TE + max_seqlen: maximum sequence length in batch + seqused: number of used tokens per sequence + """ + batch_size = hidden_states.size(0) + seq_length = hidden_states.size(1) + + if attention_mask.shape[1] != seq_length: # Generation mode with kv-caching + return ( + hidden_states.squeeze(1), + torch.arange(batch_size, dtype=torch.int64, device=hidden_states.device), + torch.arange(batch_size + 1, dtype=torch.int32, device=hidden_states.device), + 1, + 1, + ) + + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + + return ( + hidden_states.reshape(-1, *hidden_states.shape[2:])[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) diff --git a/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/state.py b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/state.py new file mode 100644 index 0000000000..29a059f334 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/te-convert-model/reference/state.py @@ -0,0 +1,546 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State dict conversion utilities for mapping weights between HF and TE model formats. + +This module provides the transform system used by convert.py to map state dicts: + +- ``mapping``: A dict of simple key renames (source_key -> target_key). Each source key is + copied directly to the corresponding target key with no modification to the tensor values. + +- ``transforms``: A list of ``StateDictTransform`` objects for multi-key merges and splits. + These handle cases where multiple source keys must be combined into one target key + (e.g., merging Q/K/V into fused QKV), or one source key must be split into multiple target keys. + + Important: When ``source_key`` is a tuple (many-to-one merge), the transform function's + parameter names are used to map each source key to a function argument. This means ``*args`` + style parameters do not work; each parameter must be explicitly named + (e.g., ``def fn(q, k, v)`` not ``def fn(*args)``). + +Adapted from nemo.lightning.io.state. 
+""" + +import inspect +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger(__name__) + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + """Context passed to every transform function. + + Attributes: + source: The source nn.Module (provides .config for reading hyperparams). + source_state: Flat dict of source model's state_dict entries. + target: The target nn.Module (provides .config for reading hyperparams). + target_state: Flat dict of target model's state_dict entries (mutated in-place). + """ + + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +class _ModelState: + """Helper to wrap a raw state_dict as a source model for apply_transforms.""" + + def __init__(self, state_dict, config=None): + self._state_dict = state_dict + self.config = config + + def state_dict(self): + return self._state_dict + + def to(self, dtype): + for k, v in self._state_dict.items(): + if v.dtype != dtype: + logger.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + +@torch.no_grad +def apply_transforms( + source: Union[nn.Module, _ModelState], + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None, + state_dict_ignored_entries: Optional[List] = None, + cast_dtype: Optional[torch.dtype] = None, +) -> TargetModuleT: + """Transform source state dict to match target model structure. + + 1. Applies simple key renames from ``mapping`` + 2. Applies each transform (merge/split operations) + 3. Copies tensors into target model parameters and buffers + 4. 
Validates shapes and checks for leftover/missing keys + + Args: + source: Source model or _ModelState wrapper. + target: Target model (parameters will be replaced in-place). + mapping: Simple key renames {source_key: target_key}. Use "*" as wildcard. + transforms: List of StateDictTransform instances for complex operations. + state_dict_ignored_entries: Target state dict keys to skip (e.g., tied weights). + cast_dtype: If set, cast output model to this dtype. + + Returns: + The target model with weights populated from source. + + Raises: + ValueError: Shape mismatch between source and target parameters. + RuntimeError: Unmapped keys remain in target state dict. + """ + if transforms is None: + transforms = [] + if state_dict_ignored_entries is None: + state_dict_ignored_entries = [] + + target_orig_dtypes = extract_dtypes(target.named_parameters()) + + target_state = target.state_dict() + ctx = TransformCTX( + source=source, + source_state=source.state_dict(), + target=target, + target_state=target_state, + ) + + # Step 1: Apply simple key renames + for key, val in mapping.items(): + ctx = StateDictTransform(key, val)(ctx) + + # Step 2: Apply complex transforms (QKV merge, embedding padding, etc.) + for transform in transforms: + ctx = transform(ctx) + + # Step 3: Copy tensors into target model parameters + _params: Dict[str, nn.Parameter] = {} + for name, param in target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError( + f"Shape mismatch for parameter {name}: target shape {param.shape} vs " + f"converted source shape {target_param.shape}" + ) + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in target model but is in source model.") + + for key, val in _params.items(): + _module, _key = target, key + if "." 
in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + _module.register_parameter(_key, val) + + # Step 4: Copy buffers + _buffers = {} + for name, buffer in target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + _module.register_buffer(_key, val) + + # Step 5: Validate no unmapped keys remain + keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys())) + keys = [key for key in keys if key not in state_dict_ignored_entries] + if len(keys) != 0: + raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.") + + if hasattr(target, "tie_weights"): + target.tie_weights() + + # Step 6: Verify no meta tensors remain (all weights were converted) + meta_tensor_keys = [] + for name, param in target.named_parameters(): + if param.is_meta: + meta_tensor_keys.append(name) + assert not meta_tensor_keys, ( + f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion." + f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?" 
def _default_transform(inp):
    """Identity transform used when a StateDictTransform has no custom function."""
    return inp


class StateDictTransform(Generic[F]):
    """A transformation that maps keys between source and target state dicts.

    Supports wildcards (*) in key patterns for matching layer indices.
    Can handle 1:1 renames, N:1 merges, and 1:N splits.

    Args:
        source_key: Source key pattern(s). Use "*" for layer index wildcards.
        target_key: Target key pattern(s).
        transform: Callable that transforms tensor values. Receives TransformCTX
            as first arg (if it accepts 'ctx'), plus matched source tensors.
    """

    def __init__(
        self,
        source_key: Union[str, Tuple[str, ...], Dict[str, str]],
        target_key: Union[str, Tuple[str, ...]],
        transform: F = _default_transform,
    ):
        """Initialize StateDictTransform with source/target key patterns and transform function."""
        self.source_key = source_key
        self.target_key = target_key
        self.transform = transform

    def __call__(self, ctx: TransformCTX) -> TransformCTX:
        """Perform the transformation on the given context.

        Mutates ``ctx.target_state`` in place and returns ``ctx``.
        """
        source_key = self.source_key
        target_key = self.target_key
        source_dict, target_dict = ctx.source_state, ctx.target_state
        # FIX: removed `np.set_printoptions(threshold=10)` — it mutated global
        # numpy print state as a hidden side effect on every call.
        fn_params = dict(inspect.signature(self.transform).parameters)
        fn_params.pop("ctx", None)
        matched = False

        if isinstance(source_key, (dict, tuple)):
            # Multi-source merge: e.g., (q_proj, k_proj, v_proj) -> layernorm_qkv.
            # For tuple sources, the transform's parameter NAMES pick which source
            # pattern feeds which argument (see module docstring).
            if isinstance(source_key, tuple):
                source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)}
            else:
                source_key_dict = source_key
            source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
            target_matches = _match_keys(list(target_dict.keys()), target_key)
            param_names = [p for p in fn_params if p in source_matches_dict]
            source_matches = [
                source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()]
                for v in param_names
            ]
            target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]]
            for layer_names_group in zip(*(source_matches + target_matches)):
                if isinstance(layer_names_group[0], str):
                    layer_names_group = [[x] for x in layer_names_group]
                for layer_names in zip(*layer_names_group):
                    target_dict[layer_names[-1]] = self.call_transform(
                        ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
                    )
            matched = True
        else:
            # Single-source: 1:1 rename or 1:N split
            source_keys = list(source_dict.keys())
            target_keys = list(target_dict.keys())

            source_matches = _match_keys(source_keys, source_key)
            if source_matches.size == 1 and source_matches == np.array(None):
                raise ValueError(f"No matches found for source key: {source_key}")

            if isinstance(target_key, str):
                target_matches = _match_keys(target_keys, target_key)
                if target_matches.size == 1 and target_matches == np.array(None):
                    raise ValueError(f"No matches found for target key: {target_key}")
            else:
                if isinstance(target_key, dict):
                    raise ValueError("Target key must be a string or a tuple of strings.")
                _matches = [_match_keys(target_keys, key) for key in target_key]
                target_matches = np.stack(_matches, axis=-1)

            multiple_sources = source_matches.ndim >= target_matches.ndim
            accepts_var_args = any(
                param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values()
            )

            if multiple_sources:
                for target_index, target_match in np.ndenumerate(target_matches):
                    try:
                        source_match = source_matches[target_index]
                    except IndexError as e:
                        logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
                        raise e
                    if accepts_var_args:
                        source_values = [source_dict[k] for k in source_match]
                        target_dict[target_match] = self.call_transform(ctx, *source_values)
                    else:
                        _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match)
                        if len(fn_params) != len(_source_match_list):
                            raise ValueError(
                                f"Mismatch between source and target keys: {source_match} vs {target_match}"
                            )
                        kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)}
                        target_dict[target_match] = self.call_transform(ctx, **kwargs)
                matched = True
            else:
                for source_index, source_match in np.ndenumerate(source_matches):
                    target_match = target_matches[source_index]
                    source_values = (
                        [source_dict[source_match]]
                        if np.isscalar(source_match)
                        else [source_dict[k] for k in source_match]
                    )
                    if accepts_var_args:
                        outputs = self.call_transform(ctx, *source_values)
                    else:
                        kwargs = dict(zip(fn_params, source_values))
                        outputs = self.call_transform(ctx, **kwargs)

                    if isinstance(target_match, str):
                        target_dict[target_match] = outputs
                    else:
                        # 1:N split — the transform returned one tensor per target key.
                        for i, t in enumerate(outputs):
                            target_dict[target_match[i]] = t
                matched = True

        if not matched:
            logger.warning(f"No matches found for source key: {source_key=} {target_key=}")
        return ctx

    def call_transform(self, ctx: TransformCTX, *args, **kwargs):
        """Invoke transform fn, injecting ctx if the function accepts it.

        Raises:
            ValueError: If the number of matched tensors does not fit the
                transform function's signature.
        """
        func_params = inspect.signature(self.transform).parameters
        expected_num_args = len([p for p in func_params if p not in ["self", "ctx"]])
        provided_num_args = len(args) + len(kwargs)
        accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values())

        if not accepts_var_args and provided_num_args != expected_num_args:
            raise ValueError(
                f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}."
            )

        if "ctx" in func_params:
            return self.transform(ctx, *args, **kwargs)

        return self.transform(*args, **kwargs)
{provided_num_args}." + ) + + if "ctx" in func_params: + return self.transform(ctx, *args, **kwargs) + + return self.transform(*args, **kwargs) + + +def _match_keys(keys: List[str], pattern: str) -> np.ndarray: + """Match state dict keys against a pattern with wildcards. + + Supports: + - "*" matches a single path segment (e.g., layer index) + - "**" matches any characters including dots + + Returns an ndarray where each dimension corresponds to a wildcard position. + """ + escaped_pattern = "" + i = 0 + wildcard_positions = [] + while i < len(pattern): + if pattern[i : i + 2] == "**": + escaped_pattern += r"(.+)" + wildcard_positions.append("**") + i += 2 + elif pattern[i] == "*": + escaped_pattern += r"([^.]+)" + wildcard_positions.append("*") + i += 1 + else: + if pattern[i] == ".": + escaped_pattern += r"\." + else: + escaped_pattern += pattern[i] + i += 1 + + regex_pattern = re.compile("^" + escaped_pattern + "$") + num_wildcards = len(wildcard_positions) + wildcard_matches = [[] for _ in range(num_wildcards)] + + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + for i, group in enumerate(match.groups()): + if group not in wildcard_matches[i]: + wildcard_matches[i].append(group) + + for i in range(len(wildcard_matches)): + wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x) + + shape = [len(matches) for matches in wildcard_matches] + if len(wildcard_matches) == 0: + shape = [1] + + output_array = np.empty(shape, dtype=object) + + for key in filter(lambda x: x is not None, keys): + match = regex_pattern.match(key) + if match: + indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())] + output_array[tuple(indices)] = key + + return output_array + + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], +) -> Callable[[F], StateDictTransform[F]]: ... 
+ + +@overload +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F +) -> StateDictTransform[F]: ... + + +def state_transform( + source_key: Union[str, Tuple[str, ...], Dict[str, str]], + target_key: Union[str, Tuple[str, ...]], + fn: Optional[F] = None, +): + """Create a StateDictTransform. Can be used as a decorator or called directly. + + Usage as decorator: + @state_transform(source_key="a.*.weight", target_key="b.*.weight") + def my_transform(ctx, weight): + return weight * 2 + + Usage with fn argument (inline): + state_transform(source_key=(...), target_key="...", fn=TransformFns.merge_qkv) + """ + + def wrapper(fn) -> StateDictTransform: + return StateDictTransform(source_key, target_key, fn) + + if fn is None: + return wrapper + return wrapper(fn) + + +class TransformFns: + """Common transform functions for state dict conversion.""" + + @staticmethod + def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor): + """Split interleaved fused QKV into separate Q, K, V tensors. + + Handles GQA by computing the correct slicing for grouped query heads. 
+ """ + target_config = ctx.target.config + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + hidden_size = target_config.hidden_size + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1]) + hidden_size = linear_qkv.size(-1) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + return q_proj, k_proj, v_proj + + @staticmethod + def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Merge separate Q, K, V into interleaved fused QKV tensor. + + Handles GQA: for each query group, interleaves Q heads with their K/V heads. + Layout: [group0_q_heads, group0_k, group0_v, group1_q_heads, group1_k, group1_v, ...] 
+ """ + target_config = ctx.target.config + head_num = target_config.num_attention_heads + num_query_groups = target_config.num_key_value_heads + heads_per_group = head_num // num_query_groups + hidden_size = target_config.hidden_size + head_size = hidden_size // head_num + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:]) + new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:]) + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + return qkv_weights + + @staticmethod + def merge_fc1(gate: torch.Tensor, up: torch.Tensor): + """Merge gate and up projections into concatenated fc1 (for SwiGLU).""" + return torch.cat((gate, up), dim=0) + + @staticmethod + def split_fc1(linear_fc1: torch.Tensor): + """Split concatenated fc1 back into gate and up projections.""" + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + return gate_proj, up_proj + + @staticmethod + def prune_padding(ctx: TransformCTX, embedding: torch.Tensor): + """Prune embedding to original vocab_size (remove FP8 padding).""" + return embedding[: ctx.target.config.vocab_size, :] + + +def extract_dtypes(ckpt): + """Extract dtype for each parameter/tensor in a named iterator.""" + dtypes = {} + for key, val in ckpt: + if hasattr(val, "dtype"): + dtypes[key] = val.dtype + elif hasattr(val, "data") and hasattr(val.data, 
"dtype"): + dtypes[key] = val.data.dtype + return dtypes diff --git a/bionemo-recipes/claude-plugin/skills/write-golden-tests/SKILL.md b/bionemo-recipes/claude-plugin/skills/write-golden-tests/SKILL.md new file mode 100644 index 0000000000..f20ea5998d --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/write-golden-tests/SKILL.md @@ -0,0 +1,107 @@ +--- +name: write-golden-tests +description: > + Write golden value tests and conversion tests for a TransformerEngine model. + Triggers when user asks to test TE conversion, write golden tests, + validate model equivalence, or verify conversion correctness. +allowed-tools: Read, Grep, Glob, Write, Edit, Bash, Agent +argument-hint: '[test-dir-path]' +--- + +# Write Golden Value and Conversion Tests + +You are writing tests that prove a TransformerEngine model produces identical outputs to the original HuggingFace model. Read the reference test files first. + +## Reference Files + +- `reference/test_esm2_example.py` — Encoder model test pattern +- `reference/test_llama3_example.py` — Decoder model test pattern + +## Test Categories + +### 1. Golden Value Test (Most Important) + +Proves numerical equivalence between HF and TE models: + +```python +def test_golden_values(self): + model_hf = OriginalModel.from_pretrained(model_id, dtype=torch.bfloat16).cuda() + model_te = convert_hf_to_te(model_hf) + model_te.to("cuda") + + input_data = self.prepare_test_input() + + with torch.no_grad(): + hf_out = model_hf(**input_data) + te_out = model_te(**input_data) + + # Loss should be very close + torch.testing.assert_close(te_out.loss, hf_out.loss, atol=1e-2, rtol=1e-3) + # Logits may have larger absolute differences but small relative error + torch.testing.assert_close(te_out.logits, hf_out.logits, atol=2.0, rtol=1e-4) +``` + +### 2. 
Roundtrip Conversion Test + +Proves HF->TE->HF preserves weights: + +```python +def test_roundtrip_conversion(self): + model_hf_orig = OriginalModel.from_pretrained(model_id) + model_te = convert_hf_to_te(model_hf_orig) + model_hf_back = convert_te_to_hf(model_te) + + for (name_orig, param_orig), (name_back, param_back) in zip( + model_hf_orig.named_parameters(), model_hf_back.named_parameters() + ): + torch.testing.assert_close( + param_orig, param_back, msg=f"Mismatch in {name_orig}" + ) +``` + +### 3. Forward/Backward Smoke Test + +```python +def test_forward_backward(self): + model_te = create_te_model() + input_data = self.prepare_test_input() + output = model_te(**input_data) + output.loss.backward() + # Verify gradients exist + for name, param in model_te.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" +``` + +### 4. FP8 Smoke Test + +```python +def test_fp8_forward_backward(self): + from transformer_engine.common.recipe import DelayedScaling, Format + + config = create_config(layer_precision=["fp8"] * num_layers) + recipe = DelayedScaling(fp8_format=Format.HYBRID) + model = MyTEModel(config, fp8_recipe=recipe).cuda() + input_data = self.prepare_test_input() + output = model(**input_data) + output.loss.backward() +``` + +### 5. 
Meta Device Init Test + +```python +def test_meta_device_init(self): + with torch.device("meta"): + model = MyTEModel(config) + model.init_empty_weights() + # Verify no meta tensors remain + for name, param in model.named_parameters(): + assert not param.is_meta, f"{name} still on meta device" +``` + +## Test Tolerances + +- **Loss**: atol=1e-2, rtol=1e-3 (should be very close) +- **Logits**: atol=2.0, rtol=1e-4 (larger absolute due to accumulated numerical differences) +- **Hidden states**: atol=0.1, rtol=0.05 +- **FP8 loss**: atol=0.1, rtol=0.05 (FP8 introduces more error) diff --git a/bionemo-recipes/claude-plugin/skills/write-golden-tests/reference/test_esm2_example.py b/bionemo-recipes/claude-plugin/skills/write-golden-tests/reference/test_esm2_example.py new file mode 100644 index 0000000000..5a4dabff31 --- /dev/null +++ b/bionemo-recipes/claude-plugin/skills/write-golden-tests/reference/test_esm2_example.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reference: ESM2 test implementation (encoder pattern). + +Shows how to implement golden value tests for an encoder model. 
class TestESM2Model:
    """Golden value, roundtrip, and smoke tests for the ESM2 TE model (encoder pattern)."""

    upstream_model_id = "facebook/esm2_t6_8M_UR50D"

    def get_test_input_data(self):
        """Tokenize two protein sequences, collate with MLM masking, and move tensors to CUDA."""
        tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer")
        test_proteins = [
            "MLSATEKLSDYISSLFASVSIINSISTEDLFFLK",
            "MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQ",
        ]
        encoded = []
        for protein in test_proteins:
            encoded.append(tokenizer(protein, truncation=True, max_length=128))
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
        batch = collator(encoded)
        on_device = {}
        for key, value in batch.items():
            on_device[key] = value.to("cuda") if isinstance(value, torch.Tensor) else value
        return on_device

    def test_golden_values(self):
        """HF and TE models should produce matching outputs on the same batch."""
        from transformers.models.esm.modeling_esm import EsmForMaskedLM

        model_hf = EsmForMaskedLM.from_pretrained(self.upstream_model_id, dtype=torch.bfloat16).cuda()
        model_te = convert_esm_hf_to_te(model_hf).cuda()
        batch = self.get_test_input_data()

        with torch.no_grad():
            reference = model_hf(**batch)
            converted = model_te(**batch)

        # NOTE: These tolerances are model-specific. ESM2 needs slightly higher due to
        # numerical differences in TE's fused attention vs HF's unfused attention.
        torch.testing.assert_close(converted.loss, reference.loss, atol=2e-2, rtol=1e-2)
        torch.testing.assert_close(converted.logits, reference.logits, atol=2.0, rtol=1e-4)

    def test_roundtrip_conversion(self):
        """HF->TE->HF should preserve all weights."""
        from transformers.models.esm.modeling_esm import EsmForMaskedLM

        model_hf = EsmForMaskedLM.from_pretrained(self.upstream_model_id).cuda()
        model_hf_back = convert_esm_te_to_hf(convert_esm_hf_to_te(model_hf))

        for (n1, p1), (n2, p2) in zip(model_hf.named_parameters(), model_hf_back.named_parameters()):
            torch.testing.assert_close(p1, p2, msg=f"Roundtrip mismatch: {n1}")

    def test_forward_backward(self):
        """Smoke test: forward + backward pass should work."""
        from transformers.models.esm.modeling_esm import EsmForMaskedLM

        model_hf = EsmForMaskedLM.from_pretrained(self.upstream_model_id, dtype=torch.bfloat16).cuda()
        model_te = convert_esm_hf_to_te(model_hf).cuda()
        batch = self.get_test_input_data()

        loss = model_te(**batch).loss
        loss.backward()

        for name, param in model_te.named_parameters():
            if param.requires_grad:
                assert param.grad is not None, f"No gradient for {name}"
class TestLlama3Model:
    """Golden value and roundtrip tests for the Llama3 TE model (decoder pattern)."""

    upstream_model_id = "meta-llama/Llama-3.2-1B-Instruct"

    def get_test_input_data(self):
        """Tokenize two short texts, collate for causal LM, and move tensors to CUDA."""
        tokenizer = AutoTokenizer.from_pretrained(self.upstream_model_id)
        # Llama tokenizers ship without a pad token; reuse EOS so batching works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        test_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
        ]
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        batch = collator([tokenizer(text) for text in test_texts])
        on_device = {}
        for key, value in batch.items():
            on_device[key] = value.to("cuda") if isinstance(value, torch.Tensor) else value
        return on_device

    def test_golden_values(self):
        """HF and TE models should produce matching outputs on the same batch."""
        import transformers

        model_hf = transformers.LlamaForCausalLM.from_pretrained(self.upstream_model_id, dtype=torch.bfloat16).cuda()
        model_te = convert_llama_hf_to_te(model_hf).cuda()
        batch = self.get_test_input_data()

        with torch.no_grad():
            reference = model_hf(**batch)
            converted = model_te(**batch)

        torch.testing.assert_close(converted.loss, reference.loss, atol=5e-3, rtol=0.01)
        torch.testing.assert_close(converted.logits, reference.logits, atol=1.5, rtol=0.01)

    def test_roundtrip_conversion(self):
        """HF->TE->HF preserves weights."""
        import transformers

        model_hf = transformers.LlamaForCausalLM.from_pretrained(self.upstream_model_id).cuda()
        model_hf_back = convert_llama_te_to_hf(convert_llama_hf_to_te(model_hf))

        for (n1, p1), (n2, p2) in zip(model_hf.named_parameters(), model_hf_back.named_parameters()):
            torch.testing.assert_close(p1, p2, msg=f"Roundtrip mismatch: {n1}")
class ClaudeRunner:
    """Runs Claude Code in --print mode with the bionemo-recipes plugin."""

    def run(self, prompt: str, cwd: str, max_budget: float = 5.0, timeout: int = 600) -> dict:
        """Execute a Claude Code prompt and return the JSON result.

        Args:
            prompt: The prompt to send to Claude.
            cwd: Working directory for Claude.
            max_budget: Maximum API budget in USD.
            timeout: Timeout in seconds.

        Returns:
            Parsed JSON output from Claude.

        Raises:
            RuntimeError: If the Claude CLI exits with a non-zero status.
        """
        command = ["claude", "-p", prompt]
        command += ["--output-format", "json"]
        command += ["--dangerously-skip-permissions"]
        command += ["--add-dir", str(PLUGIN_DIR)]
        command.append(f"--max-budget-usd={max_budget}")

        proc = subprocess.run(
            command,
            cwd=cwd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if proc.returncode != 0:
            # Truncate captured streams so a runaway transcript doesn't flood the report.
            raise RuntimeError(
                f"Claude exited with code {proc.returncode}.\n"
                f"stdout: {proc.stdout[:2000]}\n"
                f"stderr: {proc.stderr[:2000]}"
            )
        return json.loads(proc.stdout)
" + "Create the TE model class, conversion utilities, and a basic test.", + cwd=str(dst), + max_budget=5.0, + ) + return dst diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-bert/config.py b/bionemo-recipes/integration-tests/fixtures/barebones-bert/config.py new file mode 100644 index 0000000000..2a6a0a2189 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-bert/config.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from transformers import PretrainedConfig + + +class SimpleBertConfig(PretrainedConfig): + """Minimal BERT configuration for testing.""" + + model_type = "simple_bert" + + def __init__( # noqa: D107 + self, + vocab_size=1000, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=512, + max_position_embeddings=128, + layer_norm_eps=1e-5, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + pad_token_id=0, + mask_token_id=4, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.mask_token_id = mask_token_id diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-bert/modeling_simple_bert.py b/bionemo-recipes/integration-tests/fixtures/barebones-bert/modeling_simple_bert.py new file mode 100644 index 0000000000..bb4f643417 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-bert/modeling_simple_bert.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from config import SimpleBertConfig +from transformers import PreTrainedModel +from transformers.modeling_outputs import MaskedLMOutput + + +class SimpleBertModel(PreTrainedModel): + """Minimal BERT encoder using nn.TransformerEncoder.""" + + config_class = SimpleBertConfig + + def __init__(self, config: SimpleBertConfig): # noqa: D107 + super().__init__(config) + self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.embedding_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=config.hidden_size, + nhead=config.num_attention_heads, + dim_feedforward=config.intermediate_size, + dropout=config.hidden_dropout_prob, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers) + + def forward(self, input_ids, attention_mask=None): # noqa: D102 + seq_len = input_ids.size(1) + position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + + embeddings = self.token_embeddings(input_ids) + self.position_embeddings(position_ids) + embeddings = self.embedding_layer_norm(embeddings) + embeddings = self.embedding_dropout(embeddings) + + # Convert attention_mask from (batch, seq) to (batch, seq) bool mask for TransformerEncoder + src_key_padding_mask = None + if attention_mask is not None: + src_key_padding_mask = attention_mask == 0 # True = ignore + + hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask) + return hidden_states + + +class SimpleBertForMaskedLM(PreTrainedModel): + """Minimal BERT for 
masked language modeling.""" + + config_class = SimpleBertConfig + + def __init__(self, config: SimpleBertConfig): # noqa: D107 + super().__init__(config) + self.bert = SimpleBertModel(config) + + # LM head: Linear -> GELU -> LayerNorm -> Linear + self.lm_head_dense = nn.Linear(config.hidden_size, config.hidden_size) + self.lm_head_act = nn.GELU() + self.lm_head_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.lm_head_proj = nn.Linear(config.hidden_size, config.vocab_size) + + def forward(self, input_ids, attention_mask=None, labels=None): # noqa: D102 + hidden_states = self.bert(input_ids, attention_mask=attention_mask) + + # LM head + x = self.lm_head_dense(hidden_states) + x = self.lm_head_act(x) + x = self.lm_head_norm(x) + logits = self.lm_head_proj(x) + + loss = None + if labels is not None: + loss_fn = nn.CrossEntropyLoss(ignore_index=-100) + loss = loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + return MaskedLMOutput(loss=loss, logits=logits) diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-bert/requirements.txt b/bionemo-recipes/integration-tests/fixtures/barebones-bert/requirements.txt new file mode 100644 index 0000000000..e6ef03f4f6 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-bert/requirements.txt @@ -0,0 +1,3 @@ +torch>=2.0 +transformers>=4.40 +pytest diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-bert/test_model.py b/bionemo-recipes/integration-tests/fixtures/barebones-bert/test_model.py new file mode 100644 index 0000000000..bd8beb220e --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-bert/test_model.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from config import SimpleBertConfig +from modeling_simple_bert import SimpleBertForMaskedLM + + +def _make_model(): + config = SimpleBertConfig() + return SimpleBertForMaskedLM(config) + + +def _make_inputs(config, batch_size=4, seq_len=32): + input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len)) + labels = torch.full_like(input_ids, -100) + mask_pos = torch.rand(batch_size, seq_len) < 0.15 + labels[mask_pos] = input_ids[mask_pos] + input_ids[mask_pos] = config.mask_token_id + return input_ids, labels + + +def test_forward_pass(): + model = _make_model() + input_ids, labels = _make_inputs(model.config) + output = model(input_ids, labels=labels) + assert output.logits.shape == (4, 32, model.config.vocab_size) + assert output.loss is not None + assert output.loss.item() > 0 + + +def test_backward_pass(): + model = _make_model() + input_ids, labels = _make_inputs(model.config) + output = model(input_ids, labels=labels) + output.loss.backward() + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + +def test_loss_decreases(): + model = _make_model() + model.train() + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + input_ids, labels = _make_inputs(model.config) + + losses = [] + for _ in range(10): + optimizer.zero_grad() + output = model(input_ids, labels=labels) + output.loss.backward() + optimizer.step() + losses.append(output.loss.item()) + + assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> 
{losses[-1]:.4f}" diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-bert/train.py b/bionemo-recipes/integration-tests/fixtures/barebones-bert/train.py new file mode 100644 index 0000000000..4d0a4fb2a1 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-bert/train.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from config import SimpleBertConfig +from modeling_simple_bert import SimpleBertForMaskedLM + + +def main(): + """Train a minimal BERT model for a few steps to verify it works.""" + config = SimpleBertConfig() + model = SimpleBertForMaskedLM(config) + model.train() + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + # Random training data + batch_size, seq_len = 4, 32 + input_ids = torch.randint(1, config.vocab_size, (batch_size, seq_len)) + + # Create labels: mask ~15% of tokens, rest are -100 (ignored) + labels = torch.full_like(input_ids, -100) + mask_positions = torch.rand(batch_size, seq_len) < 0.15 + labels[mask_positions] = input_ids[mask_positions] + input_ids[mask_positions] = config.mask_token_id + + losses = [] + for step in range(10): + optimizer.zero_grad() + output = model(input_ids, labels=labels) + loss = output.loss + loss.backward() + optimizer.step() + losses.append(loss.item()) + print(f"Step {step}: loss={loss.item():.4f}") + + assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" + print("Training complete. Loss decreased successfully.") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-llama/config.py b/bionemo-recipes/integration-tests/fixtures/barebones-llama/config.py new file mode 100644 index 0000000000..3fe6aeb489 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-llama/config.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal Llama-style causal LM configuration.""" + +from transformers import PretrainedConfig + + +class SimpleLlamaConfig(PretrainedConfig): + """Minimal Llama-style configuration.""" + + model_type = "simple_llama" + + def __init__( # noqa: D107 + self, + vocab_size=1000, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=512, + max_position_embeddings=128, + rms_norm_eps=1e-5, + pad_token_id=0, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.rms_norm_eps = rms_norm_eps diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-llama/modeling_simple_llama.py b/bionemo-recipes/integration-tests/fixtures/barebones-llama/modeling_simple_llama.py new file mode 100644 index 0000000000..1f31c97b29 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-llama/modeling_simple_llama.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Minimal Llama-style causal LM with Group Query Attention. + +This is a vanilla PyTorch implementation with NO TransformerEngine. +It uses separate Q/K/V projections and SwiGLU FFN, which are the key +patterns that TE conversion needs to handle (Q/K/V fusion, gate/up fusion). +""" + +from typing import ClassVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from config import SimpleLlamaConfig +from transformers import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization.""" + + def __init__(self, hidden_size, eps=1e-5): # noqa: D107 + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x): # noqa: D102 + variance = x.float().pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return (self.weight * x).to(x.dtype) + + +def apply_rotary_pos_emb(x, cos, sin): # noqa: D103 + x1, x2 = x[..., ::2], x[..., 1::2] + rotated = torch.stack((-x2, x1), dim=-1).flatten(-2) + return x * cos + rotated * sin + + +class SimpleLlamaAttention(nn.Module): + """GQA attention with separate Q/K/V projections.""" + + def __init__(self, config): # noqa: D107 + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = 
config.hidden_size // config.num_attention_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + + # NOTE: Separate projections — TE conversion will fuse these + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, cos, sin): # noqa: D102 + bsz, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q = apply_rotary_pos_emb(q, cos, sin) + k = apply_rotary_pos_emb(k, cos, sin) + + # GQA: repeat K/V heads + if self.num_kv_groups > 1: + k = k.repeat_interleave(self.num_kv_groups, dim=1) + v = v.repeat_interleave(self.num_kv_groups, dim=1) + + attn = F.scaled_dot_product_attention(q, k, v, is_causal=True) + attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + return self.o_proj(attn) + + +class SimpleLlamaMLP(nn.Module): + """SwiGLU FFN with separate gate/up projections.""" + + def __init__(self, config): # noqa: D107 + super().__init__() + # NOTE: Separate gate and up projections — TE conversion will fuse these + self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + def forward(self, x): # noqa: D102 + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class SimpleLlamaLayer(nn.Module): + 
"""Single Llama transformer layer with pre-norm and residual connections.""" + + def __init__(self, config): # noqa: D107 + super().__init__() + self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) + self.self_attn = SimpleLlamaAttention(config) + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) + self.mlp = SimpleLlamaMLP(config) + + def forward(self, hidden_states, cos, sin): # noqa: D102 + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states, cos, sin) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + return residual + hidden_states + + +class SimpleLlamaModel(PreTrainedModel): + """Llama base model with token embeddings and transformer layers.""" + + config_class = SimpleLlamaConfig + + def __init__(self, config): # noqa: D107 + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.layers = nn.ModuleList([SimpleLlamaLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) + + # Precompute rotary embeddings — repeat each freq for the interleaved (even/odd) layout + head_dim = config.hidden_size // config.num_attention_heads + freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(config.max_position_embeddings).float() + emb = torch.outer(t, freqs) + # Repeat each frequency so shape is (seq, head_dim) matching x[..., ::2]/x[..., 1::2] interleaving + cos = emb.cos().repeat_interleave(2, dim=-1) + sin = emb.sin().repeat_interleave(2, dim=-1) + self.register_buffer("cos_cached", cos[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", sin[None, None, :, :], persistent=False) + self.post_init() + + def 
forward(self, input_ids, **kwargs): # noqa: D102 + hidden_states = self.embed_tokens(input_ids) + seq_len = input_ids.size(1) + cos = self.cos_cached[:, :, :seq_len, :] + sin = self.sin_cached[:, :, :seq_len, :] + + for layer in self.layers: + hidden_states = layer(hidden_states, cos, sin) + + return self.norm(hidden_states) + + +class SimpleLlamaForCausalLM(PreTrainedModel): + """Llama causal language model with tied embeddings.""" + + config_class = SimpleLlamaConfig + _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.weight": "model.embed_tokens.weight"} + + def __init__(self, config): # noqa: D107 + super().__init__(config) + self.model = SimpleLlamaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): # noqa: D102 + return self.model.embed_tokens + + def get_output_embeddings(self): # noqa: D102 + return self.lm_head + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): # noqa: D102 + hidden_states = self.model(input_ids) + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) + + return CausalLMOutputWithPast(loss=loss, logits=logits) diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-llama/requirements.txt b/bionemo-recipes/integration-tests/fixtures/barebones-llama/requirements.txt new file mode 100644 index 0000000000..e6ef03f4f6 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-llama/requirements.txt @@ -0,0 +1,3 @@ +torch>=2.0 +transformers>=4.40 +pytest diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-llama/test_model.py b/bionemo-recipes/integration-tests/fixtures/barebones-llama/test_model.py new file mode 100644 index 0000000000..2f0e0ca03a --- /dev/null +++ 
b/bionemo-recipes/integration-tests/fixtures/barebones-llama/test_model.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sanity tests for SimpleLlamaForCausalLM.""" + +import torch +from config import SimpleLlamaConfig +from modeling_simple_llama import SimpleLlamaForCausalLM + + +def test_forward_pass(): + config = SimpleLlamaConfig() + model = SimpleLlamaForCausalLM(config) + input_ids = torch.randint(0, config.vocab_size, (2, 16)) + output = model(input_ids=input_ids) + assert output.logits.shape == (2, 16, config.vocab_size) + + +def test_backward_pass(): + config = SimpleLlamaConfig() + model = SimpleLlamaForCausalLM(config) + input_ids = torch.randint(0, config.vocab_size, (2, 16)) + labels = input_ids.clone() + output = model(input_ids=input_ids, labels=labels) + output.loss.backward() + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"No gradient for {name}" + + +def test_loss_decreases(): + torch.manual_seed(42) + config = SimpleLlamaConfig() + model = SimpleLlamaForCausalLM(config) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + # Fixed batch to overfit on + input_ids = torch.randint(5, config.vocab_size, (4, 32)) + labels = input_ids.clone() + + losses = [] + for _ in range(20): + output = model(input_ids=input_ids, 
labels=labels) + output.loss.backward() + optimizer.step() + optimizer.zero_grad() + losses.append(output.loss.item()) + + assert losses[-1] < losses[0] + + +def test_gqa_dimensions(): + config = SimpleLlamaConfig(num_attention_heads=4, num_key_value_heads=2) + model = SimpleLlamaForCausalLM(config) + attn = model.model.layers[0].self_attn + assert attn.q_proj.out_features == 4 * (config.hidden_size // 4) + assert attn.k_proj.out_features == 2 * (config.hidden_size // 4) + assert attn.v_proj.out_features == 2 * (config.hidden_size // 4) diff --git a/bionemo-recipes/integration-tests/fixtures/barebones-llama/train.py b/bionemo-recipes/integration-tests/fixtures/barebones-llama/train.py new file mode 100644 index 0000000000..11b0abe0e7 --- /dev/null +++ b/bionemo-recipes/integration-tests/fixtures/barebones-llama/train.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Basic training script for SimpleLlamaForCausalLM.""" + +import torch +from config import SimpleLlamaConfig +from modeling_simple_llama import SimpleLlamaForCausalLM + + +def main(): + """Train a minimal Llama model for a few steps to verify it works.""" + torch.manual_seed(42) + config = SimpleLlamaConfig() + model = SimpleLlamaForCausalLM(config) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + # Fixed batch to overfit on — proves the model can learn + input_ids = torch.randint(5, config.vocab_size, (4, 32)) + labels = input_ids.clone() + labels[:, :5] = -100 + + losses = [] + for step in range(20): + output = model(input_ids=input_ids, labels=labels) + loss = output.loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + + losses.append(loss.item()) + print(f"Step {step}: loss={loss.item():.4f}") + + assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}" + print("Training complete!") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/integration-tests/requirements.txt b/bionemo-recipes/integration-tests/requirements.txt new file mode 100644 index 0000000000..2951d86b11 --- /dev/null +++ b/bionemo-recipes/integration-tests/requirements.txt @@ -0,0 +1,2 @@ +pytest>=7.0 +pytest-timeout>=2.0 diff --git a/bionemo-recipes/integration-tests/test_export.py b/bionemo-recipes/integration-tests/test_export.py new file mode 100644 index 0000000000..6c3c6baa1b --- /dev/null +++ b/bionemo-recipes/integration-tests/test_export.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests: Claude creates an export script.""" + +from validators import ast_checks + + +class TestExport: + """Test that Claude can create a HuggingFace Hub export script.""" + + def test_claude_creates_export_script(self, claude_runner, pre_te_ified_bert_dir): + """Claude creates an export script for a TE model.""" + claude_runner.run( + "Create an export script that converts the HuggingFace model to TE format " + "and saves it for HuggingFace Hub distribution. Include AUTO_MAP patching " + "in config.json for trust_remote_code support.", + cwd=str(pre_te_ified_bert_dir), + ) + + export_file = pre_te_ified_bert_dir / "export.py" + assert export_file.exists(), "export.py not created" + ast_checks.validate_python_file(export_file) + + code = export_file.read_text() + assert "AUTO_MAP" in code or "auto_map" in code, "export.py does not reference AUTO_MAP" diff --git a/bionemo-recipes/integration-tests/test_fp8_addition.py b/bionemo-recipes/integration-tests/test_fp8_addition.py new file mode 100644 index 0000000000..63e406e81e --- /dev/null +++ b/bionemo-recipes/integration-tests/test_fp8_addition.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests: Claude adds FP8 support to a TE model.""" + +from validators import ast_checks, pattern_checks + + +class TestFp8Addition: + """Test that Claude can add FP8 support to an existing TE model.""" + + def test_claude_adds_fp8_to_bert(self, claude_runner, pre_te_ified_bert_dir): + """Claude adds FP8 quantization support to a TE-ified BERT model.""" + claude_runner.run( + "Add FP8 quantization support to this TransformerEngine model. " + "Add layer_precision config, get_autocast_context() method, " + "and vocabulary padding for FP8 compatibility.", + cwd=str(pre_te_ified_bert_dir), + ) + + te_model_candidates = list(pre_te_ified_bert_dir.glob("*te*.py")) + list( + pre_te_ified_bert_dir.glob("*_nv*.py") + ) + assert len(te_model_candidates) > 0, "No TE model file found" + + te_model_file = te_model_candidates[0] + ast_checks.validate_python_file(te_model_file) + pattern_checks.has_layer_precision_config(te_model_file) + pattern_checks.has_fp8_autocast(te_model_file) diff --git a/bionemo-recipes/integration-tests/test_te_conversion.py b/bionemo-recipes/integration-tests/test_te_conversion.py new file mode 100644 index 0000000000..e7b2121fe6 --- /dev/null +++ b/bionemo-recipes/integration-tests/test_te_conversion.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration tests: Claude converts vanilla models to TransformerEngine."""
+
+from validators import ast_checks, pattern_checks
+
+
+class TestBertTeConversion:
+    """Test that Claude can TE-ify a vanilla BERT model."""
+
+    def test_claude_creates_te_model(self, claude_runner, bert_fixture_dir):
+        """Claude converts a vanilla BERT to use TransformerEngine."""
+        claude_runner.run(
+            "Convert this HuggingFace BERT model to use TransformerEngine. "
+            "Create: 1) A TE model file with NV config and model classes, "
+            "2) A convert.py with bidirectional conversion and state dict mapping, "
+            "3) A basic golden value test. "
+            "Follow the patterns from the bionemo-recipes reference files.",
+            cwd=str(bert_fixture_dir),
+        )
+
+        te_model_candidates = [p for p in list(bert_fixture_dir.glob("*te*.py")) + list(bert_fixture_dir.glob("*_nv*.py")) if not p.name.startswith("test")]  # "*te*.py" also matches test_*.py ("test" contains "te")
+        assert len(te_model_candidates) > 0, "No TE model file created"
+
+        convert_file = bert_fixture_dir / "convert.py"
+        assert convert_file.exists(), "convert.py not created"
+
+        for py_file in bert_fixture_dir.glob("*.py"):
+            ast_checks.validate_python_file(py_file)
+
+        te_model_file = te_model_candidates[0]
+        pattern_checks.has_te_imports(te_model_file)
+        pattern_checks.has_state_dict_mapping(convert_file)
+        pattern_checks.has_bidirectional_conversion(convert_file)
+
+
+class TestLlamaTeConversion:
+    """Test that Claude can TE-ify a vanilla Llama model (decoder with GQA)."""
+
+    def test_claude_creates_te_model(self, claude_runner, llama_fixture_dir):
+        """Claude converts a vanilla Llama to use TransformerEngine."""
+        claude_runner.run(
+            "Convert this HuggingFace Llama-style causal LM model to use TransformerEngine. "
+            "This model uses Group Query Attention (GQA) with separate Q/K/V projections "
+            "and a SwiGLU FFN with gate/up projections. "
+            "Create: 1) A TE model file, 2) A convert.py with bidirectional conversion, "
+            "3) A basic golden value test. "
+            "Follow the patterns from the bionemo-recipes reference files.",
+            cwd=str(llama_fixture_dir),
+        )
+
+        te_model_candidates = [p for p in list(llama_fixture_dir.glob("*te*.py")) + list(llama_fixture_dir.glob("*_nv*.py")) if not p.name.startswith("test")]  # "*te*.py" also matches test_*.py ("test" contains "te")
+        assert len(te_model_candidates) > 0, "No TE model file created"
+
+        convert_file = llama_fixture_dir / "convert.py"
+        assert convert_file.exists(), "convert.py not created"
+
+        for py_file in llama_fixture_dir.glob("*.py"):
+            ast_checks.validate_python_file(py_file)
+
+        te_model_file = te_model_candidates[0]
+        pattern_checks.has_te_imports(te_model_file)
+        pattern_checks.has_state_dict_mapping(convert_file)
+        pattern_checks.has_bidirectional_conversion(convert_file)
diff --git a/bionemo-recipes/integration-tests/validators/__init__.py b/bionemo-recipes/integration-tests/validators/__init__.py
new file mode 100644
index 0000000000..16192a30bc
--- /dev/null
+++ b/bionemo-recipes/integration-tests/validators/__init__.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Validation utilities for integration tests.""" + +from validators.ast_checks import validate_python_file, validate_python_syntax +from validators.file_checks import assert_file_not_empty, assert_files_exist +from validators.pattern_checks import ( + has_bidirectional_conversion, + has_fp8_autocast, + has_layer_precision_config, + has_state_dict_mapping, + has_te_imports, + has_vocab_padding, +) + + +__all__ = [ + "assert_file_not_empty", + "assert_files_exist", + "has_bidirectional_conversion", + "has_fp8_autocast", + "has_layer_precision_config", + "has_state_dict_mapping", + "has_te_imports", + "has_vocab_padding", + "validate_python_file", + "validate_python_syntax", +] diff --git a/bionemo-recipes/integration-tests/validators/ast_checks.py b/bionemo-recipes/integration-tests/validators/ast_checks.py new file mode 100644 index 0000000000..9f4cc82888 --- /dev/null +++ b/bionemo-recipes/integration-tests/validators/ast_checks.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""AST-based validation of generated Python files."""
+
+import ast
+from pathlib import Path
+
+
+def validate_python_syntax(code: str) -> bool:
+    """Check that a string of Python code parses without syntax errors."""
+    try:
+        ast.parse(code)
+        return True
+    except SyntaxError:
+        return False
+
+
+def validate_python_file(filepath: Path) -> None:
+    """Assert that a Python file has valid syntax."""
+    assert filepath.exists(), f"File does not exist: {filepath}"
+    code = filepath.read_text()
+    try:
+        ast.parse(code)
+    except SyntaxError as e:
+        raise AssertionError(f"Syntax error in {filepath}: {e}") from e  # re-raise as AssertionError so pytest reports it as a failure, not an error
+
+
+def has_class(filepath: Path, class_name: str) -> bool:
+    """Check if a Python file defines a class with the given name."""
+    code = filepath.read_text()  # NOTE(review): unlike validate_python_file, missing/invalid files raise here rather than asserting
+    tree = ast.parse(code)
+    return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))  # ast.walk visits nested scopes too
+
+
+def has_function(filepath: Path, func_name: str) -> bool:
+    """Check if a Python file defines a function with the given name."""
+    code = filepath.read_text()
+    tree = ast.parse(code)
+    return any(  # matches methods and nested defs as well (ast.walk), sync or async
+        isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name for node in ast.walk(tree)
+    )
+
+
+def count_classes(filepath: Path) -> int:
+    """Count the number of class definitions in a Python file."""
+    code = filepath.read_text()
+    tree = ast.parse(code)
+    return sum(1 for node in ast.walk(tree) if isinstance(node, ast.ClassDef))  # nested classes are counted too
diff --git a/bionemo-recipes/integration-tests/validators/file_checks.py b/bionemo-recipes/integration-tests/validators/file_checks.py
new file mode 100644
index 0000000000..3b26a143c8
--- /dev/null
+++ b/bionemo-recipes/integration-tests/validators/file_checks.py
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""File existence and basic content checks."""
+
+from pathlib import Path
+
+
+def assert_files_exist(base_dir: Path, filenames: list[str]) -> None:
+    """Assert that all specified files exist in the base directory."""
+    missing = [f for f in filenames if not (base_dir / f).exists()]
+    assert not missing, f"Missing expected files: {missing}"  # collects all missing files so one failure reports everything
+
+
+def assert_file_not_empty(filepath: Path) -> None:
+    """Assert that a file exists and is not empty."""
+    assert filepath.exists(), f"File does not exist: {filepath}"  # checked first so a missing file gives a clearer message than stat() raising
+    assert filepath.stat().st_size > 0, f"File is empty: {filepath}"
diff --git a/bionemo-recipes/integration-tests/validators/functional_checks.py b/bionemo-recipes/integration-tests/validators/functional_checks.py
new file mode 100644
index 0000000000..80455a4cd2
--- /dev/null
+++ b/bionemo-recipes/integration-tests/validators/functional_checks.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functional validation - runs generated code if GPU is available."""
+
+import subprocess
+import sys
+from pathlib import Path
+
+import pytest
+
+
+def run_generated_tests(test_dir: Path, timeout: int = 120) -> None:
+    """Run pytest in the generated test directory.
+
+    Requires GPU. Skips if no GPU available.
+    """
+    try:
+        import torch  # imported lazily so this module loads even when torch is absent
+
+        if not torch.cuda.is_available():
+            pytest.skip("No GPU available for functional tests")
+    except ImportError:
+        pytest.skip("torch not available for functional tests")
+
+    result = subprocess.run(
+        [sys.executable, "-m", "pytest", "-v", str(test_dir)],  # sys.executable keeps the same interpreter/venv as the parent run
+        capture_output=True,
+        text=True,
+        timeout=timeout,  # subprocess.TimeoutExpired propagates on overrun, failing the calling test
+        cwd=str(test_dir),
+    )
+    assert result.returncode == 0, f"Generated tests failed:\n{result.stdout}\n{result.stderr}"
diff --git a/bionemo-recipes/integration-tests/validators/pattern_checks.py b/bionemo-recipes/integration-tests/validators/pattern_checks.py
new file mode 100644
index 0000000000..47400ba358
--- /dev/null
+++ b/bionemo-recipes/integration-tests/validators/pattern_checks.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pattern-based validation of generated code for TE conversion patterns."""
+
+import re
+from pathlib import Path
+
+
+def _read(filepath: Path) -> str:
+    """Read file contents."""
+    assert filepath.exists(), f"File does not exist: {filepath}"
+    return filepath.read_text()
+
+
+def has_te_imports(filepath: Path) -> None:
+    """Assert that the file imports TransformerEngine modules."""
+    code = _read(filepath)
+    assert "transformer_engine" in code, f"{filepath.name} does not import transformer_engine"
+    # Should import at least TransformerLayer or the pytorch module
+    assert any(
+        pattern in code
+        for pattern in [
+            "transformer_engine.pytorch.TransformerLayer",
+            "transformer_engine.pytorch",
+            "from transformer_engine",
+            "import transformer_engine",
+        ]
+    ), f"{filepath.name} does not import TE TransformerLayer or pytorch module"
+
+
+def has_state_dict_mapping(filepath: Path) -> None:
+    """Assert that the file contains a state dict mapping dictionary."""
+    code = _read(filepath)
+    # Should have a mapping dict with wildcard patterns
+    assert re.search(r"mapping\s*=\s*\{", code) or re.search(r'["\'].*\*.*["\']\s*:', code), (
+        f"{filepath.name} does not contain a state dict mapping with wildcards"
+    )
+
+
+def has_bidirectional_conversion(filepath: Path) -> None:
+    """Assert that the file has both HF->TE and TE->HF conversion functions."""
+    code = _read(filepath)
+    has_hf_to_te = bool(re.search(r"def\s+convert_\w+_hf_to_te", code))
+    has_te_to_hf = bool(re.search(r"def\s+convert_\w+_te_to_hf", code))
+    assert has_hf_to_te, f"{filepath.name} missing convert_*_hf_to_te function"
+    assert has_te_to_hf, f"{filepath.name} missing convert_*_te_to_hf function"
+
+
+def has_layer_precision_config(filepath: Path) -> None:
+    """Assert that a config or model file has layer_precision support."""
+    code = _read(filepath)
+    assert "layer_precision" in code, f"{filepath.name} does not reference layer_precision"
+
+
+def has_fp8_autocast(filepath: Path) -> None:
+    """Assert that a model file uses TE autocast for FP8."""
+    code = _read(filepath)
+    assert any(
+        pattern in code
+        for pattern in [
+            "transformer_engine.pytorch.autocast",
+            "te.autocast", "fp8_autocast",  # fp8_autocast is the long-standing TE API; te.fp8_autocast(fp8_recipe=...) matched no previous pattern
+            "autocast(enabled=",
+        ]
+    ), f"{filepath.name} does not use TE autocast"
+
+
+def has_vocab_padding(filepath: Path) -> None:
+    """Assert that the model/config handles vocabulary padding for FP8."""
+    code = _read(filepath)
+    assert "padded_vocab_size" in code or "pad" in code.lower(), f"{filepath.name} does not handle vocabulary padding"  # NOTE(review): bare "pad" is very loose (matches any "padding" mention) — deliberate tolerance?