-
Notifications
You must be signed in to change notification settings - Fork 135
Add MXFP8 and NVFP4 quantization support to LLaMA3 #1500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jomitchellnv
wants to merge
6
commits into
main
Choose a base branch
from
jm/mxfp8-nvfp4-llama3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
243 changes: 243 additions & 0 deletions
243
bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,243 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: LicenseRef-Apache2 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
| import pickle | ||
| import subprocess | ||
|
|
||
| import pytest | ||
| import torch | ||
| from transformer_engine.pytorch.fp8 import check_fp8_support | ||
|
|
||
|
|
||
| def requires_fp8(func): | ||
| """Decorator to skip tests that require FP8 support.""" | ||
| fp8_available, reason = check_fp8_support() | ||
| return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func) | ||
|
|
||
|
|
||
| requires_multi_gpu = pytest.mark.skipif( | ||
| not torch.cuda.is_available() or torch.cuda.device_count() < 2, | ||
| reason="Test requires at least 2 GPUs", | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) | ||
| @requires_fp8 | ||
| def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): | ||
| cmd = [ | ||
| "torchrun", | ||
| "--nproc_per_node=1", | ||
| "--rdzv-backend=c10d", | ||
| f"--rdzv-endpoint=localhost:{unused_tcp_port}", | ||
| os.path.relpath(__file__), | ||
| "--strategy", | ||
| strategy, | ||
| ] | ||
|
|
||
| result = subprocess.run( | ||
| cmd, | ||
| check=False, | ||
| text=True, | ||
| stdout=subprocess.PIPE, | ||
| stderr=subprocess.PIPE, | ||
| timeout=240, | ||
| ) | ||
| if result.returncode != 0: | ||
| print(f"STDOUT:\n{result.stdout}") | ||
| print(f"STDERR:\n{result.stderr}") | ||
| pytest.fail(f"Command failed with exit code {result.returncode}") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("strategy", ["ddp", "fsdp2"]) | ||
| @requires_fp8 | ||
| @requires_multi_gpu | ||
| def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port): | ||
| cmd = [ | ||
| "torchrun", | ||
| "--nproc_per_node=2", | ||
| "--rdzv-backend=c10d", | ||
| f"--rdzv-endpoint=localhost:{unused_tcp_port}", | ||
| os.path.relpath(__file__), | ||
| "--strategy", | ||
| strategy, | ||
| ] | ||
|
|
||
| result = subprocess.run( | ||
| cmd, | ||
| check=False, | ||
| text=True, | ||
| stdout=subprocess.PIPE, | ||
| stderr=subprocess.PIPE, | ||
| timeout=240, | ||
| ) | ||
| if result.returncode != 0: | ||
| print(f"STDOUT:\n{result.stdout}") | ||
| print(f"STDERR:\n{result.stderr}") | ||
| pytest.fail(f"Command failed with exit code {result.returncode}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import argparse | ||
| import enum | ||
| import os | ||
| import sys | ||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
|
|
||
| # Ensure the model directory is on sys.path for bare module imports. | ||
| sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix()) | ||
|
|
||
| import torch.distributed as dist | ||
| from torch.distributed.device_mesh import init_device_mesh | ||
| from torch.distributed.fsdp import fully_shard | ||
| from torch.optim import AdamW | ||
| from transformer_engine.pytorch.fp8 import DelayedScaling, Format | ||
|
|
||
| from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM | ||
|
|
||
| def recursive_assert(a, b, path=""): | ||
| if isinstance(a, dict) and isinstance(b, dict): | ||
| assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}" | ||
| for k in a: | ||
| recursive_assert(a[k], b[k], path=f"{path}.{k}") | ||
| elif isinstance(a, list) and isinstance(b, list): | ||
| assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}" | ||
| for i in range(len(a)): | ||
| recursive_assert(a[i], b[i], path=f"{path}.{i}") | ||
| elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): | ||
| torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}") | ||
| else: | ||
| assert a == b, f"Value mismatch at {path}: {a} != {b}" | ||
|
|
||
| class Strategy(enum.StrEnum): | ||
| DDP = "ddp" | ||
| FSDP2 = "fsdp2" | ||
|
|
||
| @dataclass | ||
| class DistributedConfig: | ||
| """Class to track distributed ranks.""" | ||
|
|
||
| rank: int = field(default_factory=dist.get_rank) | ||
| local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"])) | ||
| world_size: int = field(default_factory=dist.get_world_size) | ||
|
|
||
| def is_main_process(self) -> bool: | ||
| """This is the global rank 0 process, to be used for wandb logging, etc.""" | ||
| return self.rank == 0 | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP]) | ||
| args = parser.parse_args() | ||
|
|
||
| torch.distributed.init_process_group(backend="nccl") | ||
| dist_config = DistributedConfig() | ||
| torch.cuda.set_device(dist_config.local_rank) | ||
| device_mesh = init_device_mesh( | ||
| "cuda", | ||
| mesh_shape=(dist_config.world_size, 1), | ||
| mesh_dim_names=("dp", "tp"), | ||
| ) | ||
| device = f"cuda:{dist_config.local_rank}" | ||
|
|
||
| fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10) | ||
|
|
||
| config = NVLlamaConfig( | ||
| hidden_size=256, | ||
| intermediate_size=512, | ||
| num_hidden_layers=6, | ||
| num_attention_heads=8, | ||
| num_key_value_heads=4, | ||
| vocab_size=100, | ||
| dtype=torch.bfloat16, | ||
| ) | ||
| config.layer_precision = ["fp8"] * config.num_hidden_layers | ||
| model = NVLlamaForCausalLM(config) | ||
|
|
||
| if args.strategy is Strategy.FSDP2: | ||
| for layer in model.model.layers: | ||
| fully_shard(layer, mesh=device_mesh["dp"]) | ||
| fully_shard(model, mesh=device_mesh["dp"]) | ||
| model.to(device) | ||
|
|
||
| elif args.strategy is Strategy.DDP: | ||
| model.to(device) | ||
| model = torch.nn.parallel.DistributedDataParallel( | ||
| model, | ||
| device_ids=[dist_config.local_rank], | ||
| output_device=dist_config.local_rank, | ||
| device_mesh=device_mesh["dp"], | ||
| ) | ||
|
|
||
| optimizer = AdamW(model.parameters()) | ||
|
|
||
| # Attach FP8 recipes to the model (layer precision is already on config). | ||
| llama_model = model.module.model if args.strategy is Strategy.DDP else model.model | ||
| llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) | ||
|
|
||
| model.train() | ||
|
|
||
| generator = torch.Generator() | ||
| generator.manual_seed(torch.distributed.get_rank()) | ||
|
|
||
| for _ in range(3): | ||
| input_data = { | ||
| "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator), | ||
| "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator), | ||
| "attention_mask": torch.ones(1, 32), | ||
| } | ||
| input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()} | ||
|
|
||
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): | ||
| outputs = model(**input_data) | ||
|
|
||
| outputs.loss.backward() | ||
|
|
||
| # Access FP8 extra states directly from modules instead of state_dict() | ||
| # since state_dict() now filters them out for HuggingFace compatibility | ||
| fp8_extra_states = {} | ||
| for name, module in model.named_modules(): | ||
| if hasattr(module, "_extra_state") and callable(module._extra_state): | ||
| extra_state = module._extra_state() | ||
| if extra_state is not None and len(extra_state) > 0: | ||
| fp8_extra_states[f"{name}._extra_state"] = extra_state | ||
|
|
||
| # lm_head is BF16, not FP8, so exclude it from FP8 checks | ||
| fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key} | ||
|
|
||
| # 2 ranks, test to ensure that both ranks have the same FP8 extra states | ||
| if torch.distributed.get_world_size() == 2: | ||
| outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None | ||
| torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0) | ||
| if torch.distributed.get_rank() == 0: | ||
| assert outputs_list is not None | ||
|
|
||
| for key in outputs_list[0]: | ||
| state_1 = outputs_list[0][key] | ||
| state_2 = outputs_list[1][key] | ||
| assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0" | ||
| assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1" | ||
| dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes()) | ||
| dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes()) | ||
| recursive_assert(dict_1, dict_2) | ||
|
|
||
| # One rank, test to ensure the correct FP8 extra states are saved | ||
| if torch.distributed.get_world_size() == 1: | ||
| for key, val in fp8_extra_states.items(): | ||
| assert len(val) > 0, f"No FP8 extra states for {key}" | ||
| fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes()) | ||
| assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}" | ||
|
|
||
| torch.distributed.destroy_process_group() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's going on here