diff --git a/.devcontainer/recipes/Dockerfile b/.devcontainer/recipes/Dockerfile index c300eb7b25..0bf0921ac2 100644 --- a/.devcontainer/recipes/Dockerfile +++ b/.devcontainer/recipes/Dockerfile @@ -6,6 +6,12 @@ FROM nvcr.io/nvidia/pytorch:26.02-py3 # Remove once bug has been addressed in the nvidia/pytorch container. RUN rm -f /usr/local/lib/python*/dist-packages/transformer_engine-*.dist-info/direct_url.json +RUN --mount=type=cache,target=/var/cache/apt \ + --mount=type=cache,target=/var/lib/apt \ + apt-get update && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y tmux && \ + rm -rf /var/lib/apt/lists/* + RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \ PIP_CONSTRAINT= pip install -r /workspace/requirements.txt diff --git a/bionemo-recipes/models/esm2/tests/common/fixtures.py b/bionemo-recipes/models/esm2/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/esm2/tests/common/fixtures.py +++ b/bionemo-recipes/models/esm2/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + "Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git a/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..3294c90ccd --- /dev/null +++ b/bionemo-recipes/models/esm2/tests/common/run_distributed_dcp.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. +""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _forward(model, input_data, recipe): + """Run a forward pass and return the model outputs.""" + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + return model(**input_data) + else: + return model(**input_data) + + +def _train_one_step(model, input_data, recipe, lr=1e-4): + """Run a single training step (forward + backward + optimizer step) and return detached logits.""" + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer.zero_grad() + + outputs = _forward(model, input_data, recipe) + loss = outputs.logits.sum() + loss.backward() + optimizer.step() + + return outputs.logits.detach().clone() + + +def _run_eval_forward(model, input_data, recipe): + """Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + outputs = _forward(model, input_data, recipe) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → train → save → rebuild → load → eval → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, train one step, then eval --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) + _train_one_step(model_a, input_data, recipe) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, eval --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py index a2cbc336fd..84a9e97b6e 100644 --- a/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/esm2/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -1099,4 +1111,69 @@ def test_generate_with_cache_beam_search(self): assert output_ids.shape[0] == 2 assert output_ids.shape[1] > inputs["input_ids"].shape[1] - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. + # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe)) diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py index 04246a0d81..049be38353 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py @@ -34,9 +34,7 @@ def requires_fp8(func): ) -@pytest.mark.parametrize( - "strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))] -) +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"]) @requires_fp8 def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): cmd = [ @@ -63,9 +61,7 @@ def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port): pytest.fail(f"Command failed with exit code {result.returncode}") -@pytest.mark.parametrize( - "strategy", ["ddp", "fsdp2", pytest.param("mfsdp", marks=pytest.mark.xfail(reason="BIONEMO-2999"))] -) +@pytest.mark.parametrize("strategy", ["ddp", "fsdp2", "mfsdp"]) @requires_fp8 @requires_multi_gpu def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port): diff --git a/bionemo-recipes/models/llama3/tests/common/fixtures.py b/bionemo-recipes/models/llama3/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/llama3/tests/common/fixtures.py +++ b/bionemo-recipes/models/llama3/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + "Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git a/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..3294c90ccd --- /dev/null +++ b/bionemo-recipes/models/llama3/tests/common/run_distributed_dcp.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. +""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _forward(model, input_data, recipe): + """Run a forward pass and return the model outputs.""" + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + return model(**input_data) + else: + return model(**input_data) + + +def _train_one_step(model, input_data, recipe, lr=1e-4): + """Run a single training step (forward + backward + optimizer step) and return detached logits.""" + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer.zero_grad() + + outputs = _forward(model, input_data, recipe) + loss = outputs.logits.sum() + loss.backward() + optimizer.step() + + return outputs.logits.detach().clone() + + +def _run_eval_forward(model, input_data, recipe): + """Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + outputs = _forward(model, input_data, recipe) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → train → save → rebuild → load → eval → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, train one step, then eval --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) + _train_one_step(model_a, input_data, recipe) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, eval --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py index a2cbc336fd..84a9e97b6e 100644 --- a/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/llama3/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -1099,4 +1111,69 @@ def test_generate_with_cache_beam_search(self): assert output_ids.shape[0] == 2 assert output_ids.shape[1] > inputs["input_ids"].shape[1] - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. + # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe)) diff --git a/bionemo-recipes/models/mixtral/tests/common/fixtures.py b/bionemo-recipes/models/mixtral/tests/common/fixtures.py index a437aae3d1..000ebd9429 100644 --- a/bionemo-recipes/models/mixtral/tests/common/fixtures.py +++ b/bionemo-recipes/models/mixtral/tests/common/fixtures.py @@ -102,6 +102,26 @@ def fp8_recipe(request): return request.param +RECIPE_NAME_TO_FACTORY = { + "DelayedScaling": recipe_module.DelayedScaling, + "Float8CurrentScaling": recipe_module.Float8CurrentScaling, + "Float8BlockScaling": recipe_module.Float8BlockScaling, + "MXFP8BlockScaling": recipe_module.MXFP8BlockScaling, + "NVFP4BlockScaling": lambda: recipe_module.NVFP4BlockScaling(disable_rht=True, disable_stochastic_rounding=True), +} + + +def recipe_to_name(recipe): + """Convert a recipe instance to its CLI-passable string name.""" + return type(recipe).__name__ + + +def recipe_from_name(name): + """Reconstruct a recipe instance from its CLI-passable string name.""" + factory = RECIPE_NAME_TO_FACTORY[name] + return factory() + + @pytest.fixture(params=["bshd", "thd"]) def input_format(request): """Fixture to parametrize the input format.""" diff --git a/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py b/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py new file mode 100644 index 0000000000..3294c90ccd --- /dev/null +++ b/bionemo-recipes/models/mixtral/tests/common/run_distributed_dcp.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Worker script for distributed DCP (Distributed Checkpoint) tests. + +Launched by torchrun from BaseModelTest.test_dcp_output_parity / test_dcp_output_parity_fp8_init. +Verifies that a model sharded with FSDP2 produces identical outputs after a DCP save/load round-trip. +""" + +import argparse +import importlib.util +import os +import shutil +import sys +import tempfile +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dcp +import transformer_engine.pytorch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import fully_shard +from transformers import set_seed + + +def _setup_sys_path(): + """Add model root and tests directory to sys.path so model/test imports work.""" + script_dir = Path(__file__).resolve().parent # tests/common/ + tests_dir = script_dir.parent # tests/ + model_root = tests_dir.parent # model root (e.g., models/esm2/) + for p in [str(model_root), str(tests_dir)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def _load_tester_class(tester_file, class_name): + """Dynamically load a tester class from a file path.""" + # Ensure the tester file's directory tree is importable + tester_dir = str(Path(tester_file).parent) + tester_parent = str(Path(tester_file).parent.parent) + for p in [tester_parent, tester_dir]: + if p not in sys.path: + sys.path.insert(0, p) + + spec = importlib.util.spec_from_file_location("_dcp_tester_module", tester_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + +def _build_and_shard_model(tester, config, recipe, device, device_mesh): + """Build a model (optionally with FP8 quantized_model_init), shard with FSDP2, and move to device.""" + model_class = tester.get_model_class() + + if recipe is not None: + with transformer_engine.pytorch.quantized_model_init(recipe=recipe): + model = model_class(config) + else: + model = model_class(config) + + # Shard each transformer layer, then the root model + for layer in tester.get_layer_path(model): + fully_shard(layer, mesh=device_mesh) + fully_shard(model, mesh=device_mesh) + + model.to(device) + return model + + +def _forward(model, input_data, recipe): + """Run a forward pass and return the model outputs.""" + if recipe is not None: + # torch.autocast is needed when model was built with quantized_model_init + # (weights are FP8, non-quantized ops need bf16 casting) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with transformer_engine.pytorch.autocast(recipe=recipe): + return model(**input_data) + else: + return model(**input_data) + + +def _train_one_step(model, input_data, recipe, lr=1e-4): + """Run a single training step (forward + backward + optimizer step) and return detached logits.""" + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer.zero_grad() + + outputs = _forward(model, input_data, recipe) + loss = outputs.logits.sum() + loss.backward() + optimizer.step() + + return outputs.logits.detach().clone() + + +def _run_eval_forward(model, input_data, recipe): + """Run an eval forward pass and return detached logits.""" + model.eval() + with torch.no_grad(): + outputs = _forward(model, input_data, recipe) + return outputs.logits.detach().clone() + + +def run_dcp_output_parity(tester, fp8_recipe_name=None, seed=42): + """Core DCP round-trip test: build → train → save → rebuild → load → eval → compare.""" + from tests.common.fixtures import recipe_from_name + + rank = dist.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = dist.get_world_size() + device = f"cuda:{local_rank}" + torch.cuda.set_device(local_rank) + + device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,)) + + # Resolve FP8 recipe + recipe = recipe_from_name(fp8_recipe_name) if fp8_recipe_name else None + + # Build config + set_seed(seed) + config = tester.create_test_config(dtype=torch.bfloat16, attn_input_format="bshd") + + # Prepare input data + input_data = tester.get_test_input_data("bshd", pad_to_multiple_of=32) + + # --- Model A: build, shard, train one step, then eval --- + set_seed(seed) + model_a = _build_and_shard_model(tester, config, recipe, device, device_mesh) + _train_one_step(model_a, input_data, recipe) + logits_a = _run_eval_forward(model_a, input_data, recipe) + + # --- DCP Save --- + # Rank 0 creates temp dir, broadcast path to all ranks + if rank == 0: + tmp_dir = tempfile.mkdtemp(prefix="dcp_test_") + else: + tmp_dir = None + tmp_dir_list = [tmp_dir] + dist.broadcast_object_list(tmp_dir_list, src=0) + tmp_dir = tmp_dir_list[0] + + checkpoint_path = os.path.join(tmp_dir, "checkpoint") + + state_dict_a = {"model": model_a.state_dict()} + dcp.save(state_dict_a, checkpoint_id=checkpoint_path) + + dist.barrier() + + # Free model_a + del model_a, state_dict_a + torch.cuda.empty_cache() + + # --- Model B: build fresh, shard, load, eval --- + set_seed(seed) + model_b = _build_and_shard_model(tester, config, recipe, device, device_mesh) + + state_dict_b = {"model": model_b.state_dict()} + dcp.load(state_dict_b, checkpoint_id=checkpoint_path) + model_b.load_state_dict(state_dict_b["model"], strict=False) + + logits_b = _run_eval_forward(model_b, input_data, recipe) + + # --- Compare --- + tolerances = tester.get_tolerances() + torch.testing.assert_close( + logits_a, + logits_b, + atol=tolerances.dcp_logits_atol, + rtol=tolerances.dcp_logits_rtol, + msg=lambda x: f"DCP round-trip logits mismatch: {x}", + ) + + # Cleanup + del model_b, state_dict_b + torch.cuda.empty_cache() + dist.barrier() + + if rank == 0: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"[Rank {rank}] DCP output parity test PASSED (fp8_recipe={fp8_recipe_name})") + + +if __name__ == "__main__": + _setup_sys_path() + + parser = argparse.ArgumentParser(description="DCP distributed test worker") + parser.add_argument( + "--tester-file", required=True, help="Absolute path to the test file containing the tester class" + ) + parser.add_argument("--tester-class", required=True, help="Name of the tester class (e.g., TestESM2Model)") + parser.add_argument("--fp8-recipe", default=None, help="FP8 recipe name (e.g., DelayedScaling)") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + + try: + tester_cls = _load_tester_class(args.tester_file, args.tester_class) + tester = tester_cls() + run_dcp_output_parity(tester, fp8_recipe_name=args.fp8_recipe, seed=args.seed) + finally: + dist.destroy_process_group() diff --git a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py index a2cbc336fd..84a9e97b6e 100644 --- a/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py +++ b/bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py @@ -17,6 +17,8 @@ import fnmatch import gc +import os +import subprocess from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -39,6 +41,12 @@ HAS_DATA_CENTER_GPU = False +_requires_multi_gpu = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Test requires at least 2 GPUs", +) + + @dataclass class TestTolerances: """Model-specific test tolerances for numerical comparisons.""" @@ -65,6 +73,10 @@ class TestTolerances: fp8_logits_atol: float = 5.0 fp8_logits_rtol: float = 0.1 + # DCP (distributed checkpoint) round-trip tolerances + dcp_logits_atol: float = 0.0 + dcp_logits_rtol: float = 0.0 + # Meta device initialization tolerances init_mean_atol: float = 1e-3 init_mean_rtol: float = 1e-4 @@ -1099,4 +1111,69 @@ def test_generate_with_cache_beam_search(self): assert output_ids.shape[0] == 2 assert output_ids.shape[1] > inputs["input_ids"].shape[1] - # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc. + # ==================== Distributed Checkpoint (DCP) Tests ==================== + + def _get_dcp_worker_script_path(self) -> str: + """Return the absolute path to the run_distributed_dcp.py worker script.""" + return str(Path(__file__).resolve().parent / "run_distributed_dcp.py") + + def _get_tester_file_and_class(self): + """Return (file_path, class_name) for dynamic loading in the worker subprocess.""" + import inspect + + return os.path.abspath(inspect.getfile(type(self))), type(self).__name__ + + def _run_dcp_worker(self, unused_tcp_port, fp8_recipe_name=None, nproc_per_node=2): + """Launch the DCP worker script via torchrun and assert it succeeds.""" + tester_file, class_name = self._get_tester_file_and_class() + worker_script = self._get_dcp_worker_script_path() + + cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--rdzv-backend=c10d", + f"--rdzv-endpoint=localhost:{unused_tcp_port}", + worker_script, + "--tester-file", + tester_file, + "--tester-class", + class_name, + ] + + if fp8_recipe_name is not None: + cmd.extend(["--fp8-recipe", fp8_recipe_name]) + + result = subprocess.run( + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=300, + ) + if result.returncode != 0: + print(f"STDOUT:\n{result.stdout}") + print(f"STDERR:\n{result.stderr}") + pytest.fail(f"DCP worker failed with exit code {result.returncode}") + + def test_dcp_output_parity_single_gpu(self, unused_tcp_port): + """Test FSDP2 + DCP save/load round-trip on a single GPU.""" + self._run_dcp_worker(unused_tcp_port, nproc_per_node=1) + + def test_dcp_output_parity_fp8_init_single_gpu(self, fp8_recipe, unused_tcp_port): + """Test FSDP2 + DCP save/load with FP8 quantized_model_init on a single GPU.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe), nproc_per_node=1) + + @_requires_multi_gpu + def test_dcp_output_parity(self, unused_tcp_port): + """Test that a model sharded with FSDP2 produces identical outputs after DCP save/load.""" + self._run_dcp_worker(unused_tcp_port) + + @_requires_multi_gpu + def test_dcp_output_parity_fp8_init(self, fp8_recipe, unused_tcp_port): + """Test DCP save/load with FP8 quantized_model_init.""" + from .fixtures import recipe_to_name + + self._run_dcp_worker(unused_tcp_port, fp8_recipe_name=recipe_to_name(fp8_recipe))