From d714896bfcdb386d1ad9b9b1cd841542695f659f Mon Sep 17 00:00:00 2001 From: Xibin Liu Date: Mon, 19 Jan 2026 19:15:02 +0000 Subject: [PATCH 1/3] NNX migration preparation: pure_nnx flag and init_state_fn - pure_nnx: a flag to choose pure NNX logic when NNX and linen models co-exist. - init_state_fn: a function to initialize the model state for the training. It will be set to a different function for NNX and Linen. --- .../convert_gpt3_ckpt_from_paxml.py | 13 +- src/maxtext/configs/base.yml | 5 +- src/maxtext/configs/types.py | 1 + src/maxtext/experimental/rl/grpo_trainer.py | 32 +- src/maxtext/inference/maxengine/maxengine.py | 21 +- src/maxtext/layers/train_state_nnx.py | 48 ++ .../post_train/sft/train_sft_deprecated.py | 2 +- src/maxtext/trainers/pre_train/train.py | 8 +- .../trainers/pre_train/train_compile.py | 64 ++- .../utils/generate_param_only_checkpoint.py | 20 +- src/maxtext/utils/layerwise_quantization.py | 20 +- src/maxtext/utils/lora_utils.py | 8 +- src/maxtext/utils/maxtext_utils.py | 45 +- src/maxtext/utils/maxtext_utils_nnx.py | 172 +++++++ src/maxtext/utils/model_creation_utils.py | 32 ++ src/maxtext/utils/train_utils.py | 49 +- .../generate_grpo_golden_logits.py | 28 +- .../integration/grpo_correctness.py | 12 +- .../grpo_trainer_correctness_test.py | 10 +- .../sft_trainer_correctness_test.py | 10 +- tests/unit/maxtext_utils_test.py | 457 +++++++++++++++++- tests/unit/model_creation_utils_test.py | 326 +++++++++++++ tests/unit/sharding_compare_test.py | 13 +- tests/unit/state_dtypes_test.py | 8 +- tests/unit/train_utils_test.py | 196 ++++++++ tests/utils/forward_pass_logit_checker.py | 19 +- .../gcs_benchmarks/standalone_checkpointer.py | 47 +- tools/gcs_benchmarks/standalone_dataloader.py | 4 +- 28 files changed, 1532 insertions(+), 138 deletions(-) create mode 100644 src/maxtext/layers/train_state_nnx.py create mode 100644 src/maxtext/utils/maxtext_utils_nnx.py create mode 100644 tests/unit/model_creation_utils_test.py create mode 100644 
tests/unit/train_utils_test.py diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 888cf4d2d1..9b5f0cfb21 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -35,6 +35,7 @@ """ import argparse +import functools import gc import os import sys @@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name mesh = Mesh(devices_array, cfg.mesh_axes) quant = quantizations.configure_quantization(cfg) - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if cfg.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) + if cfg.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 477ee223dc..67297751a8 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1115,8 +1115,9 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false -pure_nnx_decoder: false +enable_nnx: False +pure_nnx_decoder: False +pure_nnx: False ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index ed4836571b..69fe932be0 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -802,6 +802,7 @@ class HardwareAndMesh(BaseModel): optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") + pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") class LayoutAndSharding(BaseModel): diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 100434ef74..28eef21cb0 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -546,23 +546,43 @@ def setup_train_loop( max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - 
model = mt.from_config(config, devices=training_devices) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - inference_model = mt.from_config(config_inference, devices=inference_devices) + if config_inference.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) # create inference_state_mesh_shardings from inference_mesh + if config_inference.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( - inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False + config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None): data_buffer = [] data_buffer_lock = threading.Lock() - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) inference_prof = profiler.Profiler(config_inference, offset_step=start_step) data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 02a2f392c2..23cd2387db 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -229,17 +232,25 @@ def load_params(self, 
*args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( - self.model, None, self.config, rng, self._mesh, False + self.config, self._mesh, init_state_fn, False ) # reshard given params based on shardings from config in MaxEngine params = jax.device_put(params, state_mesh_shardings.params) state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - state, self.state_mesh_annotations = maxtext_utils.setup_decode_state( - self.model, self.config, rng1, self._mesh, None - ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py new file mode 100644 index 0000000000..9ef0e6dffd --- /dev/null +++ b/src/maxtext/layers/train_state_nnx.py @@ -0,0 +1,48 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" The NNX Unified TrainState. """ + +from typing import Any + +from flax import nnx + + +class TrainStateNNX(nnx.Module): + """ + A unified container for NNX models and optimizers. + This replaces Linen's TrainState for checkpointing. + + Linen TrainState pytree: + {“params”: {...}, “opt_state”: {}...} + TrainStateNNX state pytree: + {“model”: {...}, “optimizer”: {“opt_state”: {...}} + """ + + def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + self.model = model + self.optimizer = optimizer + + def apply_gradients(self, grads: Any): + """ + Mimics the Linen apply_gradients function. + Updates the optimizer state, applies updates to parameters, + and increments the step counter. + """ + if self.optimizer is None: + raise RuntimeError( + "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. " + "This usually happens when the state was created for inference only." 
+ ) + self.optimizer.update(self.model, grads) diff --git a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py index 7cc8f5b658..c7f6bd4740 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py @@ -85,7 +85,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index a3c39acb9f..c2f32f076c 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -75,8 +75,10 @@ VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() -def get_first_step(state): - return int(state.step) +def get_first_step(model, state): + if isinstance(model, nn.Module): + return int(state.step) + return int(state.optimizer.step.get_value()) # ----------------------------------------------------------------------------- @@ -528,7 +530,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git 
a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 74f36ea045..9462655761 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -27,6 +27,7 @@ from typing import Sequence from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning import jax from jax.experimental.serialize_executable import serialize @@ -36,6 +37,7 @@ from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations +from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.trainers.diloco import diloco @@ -44,6 +46,8 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils # pylint: disable=too-many-positional-arguments @@ -93,7 +97,10 @@ def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state quant = quantizations.configure_quantization(config) - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if config.pure_nnx: + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) + else: + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. 
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -103,18 +110,39 @@ def get_shaped_inputs(topology_mesh, config): _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) - # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, example_rng, topology_mesh - ) + if config.pure_nnx: + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) - # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) + # Shaped state + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) + + if config.pure_nnx: + # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. + logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + # For NNX, get_functional_train_with_signature expects the graphdef (static structure), + # not the raw model — mirroring how the training loop does nnx.split(train_state). 
+ with nn_partitioning.axis_rules(config.logical_axis_rules): + graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + model = graphdef + else: + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) + if config.pure_nnx: + shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + else: + shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -279,12 +307,20 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(topology_mesh) - maxtext_utils.print_shardings_params( - shaped_train_args[0].params, - state_mesh_shardings.params, - topology_mesh, - logical_annotations.params, - ) + if config.pure_nnx: + maxtext_utils.print_shardings_params( + shaped_train_args[0], + state_mesh_shardings, + topology_mesh, + logical_annotations, + ) + else: + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + logical_annotations.params, + ) # Compile print("Jitting and compiling train step...", flush=True) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 7c520cc470..2fd14b87a2 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -22,6 +22,7 @@ The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. 
""" +import functools import os.path from typing import Sequence @@ -42,8 +43,6 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -Transformer = models.transformer_as_linen - def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" @@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager + None, config, mesh, checkpoint_manager, init_state_fn ) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") @@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 4be05ff7e1..36e612a3f9 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -30,6 +30,7 @@ """ +import functools import os from typing import Any, Sequence @@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType): # Model and quantization config self.quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, None, self.config, self.rng, self._mesh, False - ) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX 
support has not been implemented yet.") + else: + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + + self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) def load_and_quantize(self) -> None: """ diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 03095edd73..24099ef22a 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,6 +14,7 @@ """ Common LoRA utils needed to support LoRA adapters.""" +from functools import partial import json import jax @@ -166,7 +167,12 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh, True) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index aa775c155e..35ba11c389 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -196,8 +196,11 @@ def get_train_input_output_trees(func, input_args, input_kwargs): serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file) shaped_batch = get_shaped_batch(config) - example_rng = jax.random.PRNGKey(0) - shaped_input_args = (state, shaped_batch, example_rng) + if config.pure_nnx: + shaped_input_args = (state, shaped_batch) + else: + example_rng = jax.random.PRNGKey(0) + shaped_input_args = (state, shaped_batch, example_rng) shaped_input_kwargs = {} in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices) @@ -1050,14 +1053,13 @@ def get_abstract_param(model, config): return abstract_vars -def setup_decode_state(model, config, rng, mesh, checkpoint_manager): +def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): """Setup decode state by loading params from a checkpoint. 
Args: - model: the flax model to initialize config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: Checkpoint manager + init_state_fn: function to initialize the model state Returns: state: state with decode params loaded from the checkpoint @@ -1067,12 +1069,12 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): # generate random params max_logging.log("No decode checkpoint specified - generating random weights.") state, state_mesh_annotations, _, _ = setup_initial_state( - model, None, None, config, rng, mesh, checkpoint_manager, False + None, config, mesh, checkpoint_manager, init_state_fn, False ) else: # Load params from checkpoint max_logging.log(f"Loading decode params from {config.load_parameters_path}") - unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False) + unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(config, mesh, init_state_fn, False) with nn_partitioning.axis_rules(config.logical_axis_rules): params = checkpointing.load_params_from_path( config.load_parameters_path, @@ -1087,40 +1089,35 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): return state, state_mesh_annotations -def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): +def setup_training_state(data_iterator, config, mesh, checkpoint_manager, init_state_fn): is_training = True return setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training, ) def setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training=True, ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
Args: - model: the flax model to initialize - tx: the optax.GradientTransformation + data_iterator: data iterator config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: an Orbax checkpointing.CheckpointManager object + init_state_fn: function to initialize the training state is_training: True to initialize training state, False for decode state Returns: @@ -1129,7 +1126,7 @@ def setup_initial_state( """ unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( - model, tx, config, rng, mesh, is_training + config, mesh, init_state_fn, is_training ) # Initialization @@ -1164,14 +1161,14 @@ def setup_initial_state( # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] else: - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" # pylint: disable=not-callable state = jax.jit( init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings, - )(rng) + )() if raw_params: # If we loaded a partial state, we need to merge it. 
state = state.replace(params=raw_params) @@ -1180,8 +1177,8 @@ def setup_initial_state( return state, state_mesh_annotations, state_mesh_shardings, data_iterator -def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) +def get_logical_annotations(config, mesh, init_state_fn): + init_state_partial = init_state_fn with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) @@ -1189,9 +1186,9 @@ def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): return logical_annotations -def get_abstract_state(model, tx, config, rng, mesh, is_training=True): +def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py new file mode 100644 index 0000000000..7378928ef2 --- /dev/null +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -0,0 +1,172 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Utils for MaxText NNX. 
""" + +from functools import partial +from typing import Callable + +from flax import nnx +import jax +from jax.sharding import Mesh, NamedSharding + +from maxtext.utils import max_logging +from maxtext.configs import pyconfig + + +def create_nnx_rngs( + config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None +) -> nnx.Rngs: + """ + Create NNX Rngs + + Args: + config: the configuration + is_training: if the Rngs are for training + rng_key: the Rng key + + Returns: + The NNX Rngs + """ + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + + if is_training: + return nnx.Rngs( + params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) + ) + return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference + + +def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: + """Get named sharding from NNX abstract state. + + Args: + abstract_state: NNX model abstract state created from nnx.get_abstract_model. + + Returns: + named sharding structure + """ + # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we + # get the existing sharding from the abstract_state. + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + return jax.tree.map( + lambda x: x.sharding, + abstract_state, + is_leaf=lambda x: isinstance(x, jax.ShapeDtypeStruct), + ) + + +def get_partition_spec_nnx(named_sharding: nnx.State) -> nnx.State: + """Get mesh partition spec from named sharding. + + Args: + named_sharding: NNX model named sharding. + + Returns: + mesh partition spec + """ + # The leaf is of type NamedSharding. + return jax.tree.map( + lambda x: x.spec, + named_sharding, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + + +def set_named_sharding_nnx(abstract_state: nnx.State, named_sharding: nnx.State) -> nnx.State: + """Set named sharding to NNX abstract state. 
+
+  Args:
+    abstract_state: NNX model abstract state created from nnx.get_abstract_model().
+    named_sharding: named sharding. It must have the same tree structure with abstract_state.
+
+  Returns:
+    updated abstract_state
+  """
+  return jax.tree.map(lambda x, y: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=y), abstract_state, named_sharding)
+
+
+def move_memory_to_host(path: tuple[str, ...], x: NamedSharding) -> NamedSharding:
+  """
+  Change the memory_kind of the NamedSharding to "pinned_host". This function can be
+  called by jax.tree_util.tree_map_with_path on a NNX state structure.
+
+  Args:
+    path: the tree path tuple
+    x: the NamedSharding corresponding to the path
+
+  Returns:
+    the NamedSharding with memory_kind set to "pinned_host"
+  """
+  max_logging.log(f"maxtext_utils_nnx.py: Moving {path} to host")
+  # Create the new sharding with the target memory kind
+  return x.with_memory_kind(kind="pinned_host")
+
+
+def move_memory_to_device(path: tuple[str, ...], x: NamedSharding) -> NamedSharding:
+  """
+  Change the memory_kind of the NamedSharding to "device". This function can be
+  called by jax.tree_util.tree_map_with_path on a NNX state structure.
+
+  Args:
+    path: the tree path tuple
+    x: the NamedSharding corresponding to the path
+
+  Returns:
+    the NamedSharding with memory_kind set to "device"
+  """
+  max_logging.log(f"maxtext_utils_nnx.py: Moving {path} to device")
+  # Create the new sharding with the target memory kind
+  return x.with_memory_kind(kind="device")
+
+
+def create_nnx_sharded_model(
+    abstract_model: nnx.Module,
+    init_fn: Callable,
+    mesh: Mesh | None = None,
+    named_sharding: nnx.State | None = None,
+) -> nnx.Module:
+  """
+  Create the model with the given sharding. 
+ + Args: + abstract_model: the abstract model + init_fn: the model init function + mesh: the device mesh + named_sharding: the given sharding + + Returns: + The initialized sharded model + """ + graphdef, abstract_state = nnx.split(abstract_model) + if named_sharding is None: + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + # we get the sharding directly from it. + named_sharding = get_named_sharding_nnx(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + @partial(jax.jit, out_shardings=named_sharding) + def create_sharded_state(): + model = init_fn() + return jax.lax.with_sharding_constraint(nnx.state(model), named_sharding) + + # Create the model with sharded parameters. + with jax.set_mesh(mesh): + sharded_state = create_sharded_state() + return nnx.merge(graphdef, sharded_state) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index b3057d0518..a010aac993 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -112,6 +112,38 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model +def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Returns (_create_model_partial, abstract_model) for AOT compilation. + + Unlike create_nnx_model, this does not shard parameters or load checkpoints. + It only builds the abstract shape/dtype structure needed by get_abstract_state + and optimizer construction (e.g. Muon). 
+ + Args: + config: the configuration + mesh: the device mesh + model_mode: train or inference + rng_key: optional RNG key + + Returns: + (_create_model_partial, abstract_model) where _create_model_partial() creates + a concrete model instance and abstract_model is the eval_shape result. + """ + + def _create_model(rng_key=None): + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + rngs = nnx.Rngs(params=rng_key, dropout=1) + return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) + + _create_model_partial = partial(_create_model, rng_key=rng_key) + + with nn.logical_axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model_partial) + + return _create_model_partial, abstract_model + + def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 54b2755801..9413b099ed 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -16,6 +16,8 @@ """Utils that are only interesting for training in MaxText.""" import os +from functools import partial + import jax import functools from flax.linen import partitioning as nn_partitioning @@ -33,12 +35,17 @@ from maxtext.trainers.diloco import diloco -def create_training_tools(config, model, mesh): - """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" - init_rng = jax.random.PRNGKey(config.init_weights_seed) +def create_training_optimizer(config, model): + """Creates the optimizer and learning rate schedule.""" learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon tx = optimizers.get_optimizer(config, learning_rate_schedule, model) + return learning_rate_schedule, tx + + +def create_checkpoint_manager(config, mesh, init_state_fn): + 
"""Creates the checkpoint manager."""
+  # Configure checkpoint logging before constructing the checkpoint manager.
   logger = checkpointing.setup_checkpoint_logger(config)
   if config.enable_multi_tier_checkpointing:
     checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
@@ -47,7 +54,7 @@ def create_training_tools(config, model, mesh):
       mesh,
     )
   elif config.enable_emergency_checkpoint:
-    abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
+    abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True)
     checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
       config.local_checkpoint_directory,
       config.checkpoint_dir,
@@ -85,10 +92,10 @@ def create_training_tools(config, model, mesh):
       config.enable_autocheckpoint,
     )
 
-  return init_rng, checkpoint_manager, learning_rate_schedule, tx
+  return checkpoint_manager
 
 
-def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings):
+def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=None):
   """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config."""
   if config.enable_diloco:
     functional_train = train_step
@@ -110,7 +117,9 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr
   # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
   if config.compiled_trainstep_file != "":
     max_logging.log("Loading the compiled function...")
-    execution_devices = model.mesh.devices.flatten().tolist()
+    # For NNX, model is the GraphDef (no .mesh); use the mesh passed explicitly instead. 
+ execution_mesh = mesh if mesh is not None else model.mesh + execution_devices = execution_mesh.devices.flatten().tolist() # Need to pass train signature and state to determine i/o shapes of train_state for now. p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices) max_logging.log("Loaded compiled function!") @@ -165,7 +174,9 @@ def jit_train_and_eval_step( train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh) data_sharding = sharding.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) + p_train_step = jit_train_step( + config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh + ) p_eval_step = None if eval_data_iterator: p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) @@ -197,9 +208,21 @@ def setup_train_loop(config, recorder, devices=None): from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): - model = model_creation_utils.from_config(config, devices) + is_training = True + init_rng = jax.random.PRNGKey(config.init_weights_seed) + if config.pure_nnx: + # Create abstract NNX model. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = model_creation_utils.from_config(config, devices) mesh = model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) + learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) + checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) @@ -225,7 +248,7 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) if config.enable_diloco: @@ -248,14 +271,14 @@ def setup_train_loop(config, recorder, devices=None): # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) max_utils.print_non_trivial_mesh_axis(model.mesh) maxtext_utils.print_shardings_params( state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params ) if config.use_dpo: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index e4e9f4fe8a..cae8b9e4d3 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ b/tests/assets/logits_generation/generate_grpo_golden_logits.py @@ -38,7 +38,7 @@ from maxtext.inference.maxengine import maxengine 
from maxtext.models import models from maxtext.utils import maxtext_utils -from tests.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs +from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs import numpy as np import torch import transformers @@ -73,17 +73,27 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) # With checkpoint - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules) self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # Without checkpoint - self.model_no_ckpt_loading = models.transformer_as_linen( - config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN - ) - self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state( - self.model_no_ckpt_loading, self.cfg_no_ckpt_loading, self.rng, mesh, None - ) + if self.cfg_no_ckpt_loading.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model_no_ckpt_loading = models.transformer_as_linen( + config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial( + maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng + ) + self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", diff --git a/tests/post_training/integration/grpo_correctness.py b/tests/post_training/integration/grpo_correctness.py index 44a3e28df7..adefc03a7e 100644 --- a/tests/post_training/integration/grpo_correctness.py +++ b/tests/post_training/integration/grpo_correctness.py @@ -13,6 +13,7 @@ # limitations under the License. """GRPO correctness tests""" +import functools import os import unittest @@ -60,8 +61,13 @@ def setUp(self): self.rng = jax.random.PRNGKey(42) devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, _ = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", add_bos_token=False, @@ -121,7 +127,7 @@ def _prepare_maxtext_inputs(self): ) def _prepare_trl_inputs(self): - """Prepare TRL inputs.""" + """Prepare inputs for TRL model.""" tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt") input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1) attention_mask = torch.cat( diff --git a/tests/post_training/integration/grpo_trainer_correctness_test.py b/tests/post_training/integration/grpo_trainer_correctness_test.py index 9a2cfd4078..b880a0e678 100644 --- a/tests/post_training/integration/grpo_trainer_correctness_test.py +++ b/tests/post_training/integration/grpo_trainer_correctness_test.py @@ -25,6 +25,7 @@ pytest tests/post_training/integration/grpo_trainer_correctness_test.py """ +import functools import os import subprocess import sys @@ -72,8 +73,13 @@ def setup_maxtext_model(config, mesh): init_rng = jax.random.PRNGKey(config.init_weights_seed) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, state_mesh_annotations = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, state_mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, config.logical_axis_rules) data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) reference_params = jax.tree.map(jnp.copy, state.params["params"]) diff --git a/tests/post_training/integration/sft_trainer_correctness_test.py b/tests/post_training/integration/sft_trainer_correctness_test.py index beeb2036d9..89ac19d0f3 100644 --- a/tests/post_training/integration/sft_trainer_correctness_test.py +++ b/tests/post_training/integration/sft_trainer_correctness_test.py @@ -24,6 +24,7 @@ pytest tests/post_training/integration/sft_trainer_correctness_test.py """ +import functools import os.path import subprocess import sys @@ -117,8 +118,13 @@ def setup_maxtext_model(config): quant = quantizations.configure_quantization(config) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) return maxtext_model, state, init_rng diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index a65a905c7f..4850e972b3 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -14,8 +14,9 @@ """Tests for the common MaxText utilities""" -from collections.abc import Callable +import functools from typing import Any +from collections.abc import Callable import unittest from unittest.mock import MagicMock, Mock @@ -28,7 +29,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN from maxtext.inference import inference_utils from maxtext.layers import quantizations from maxtext.models import models @@ -351,18 +352,31 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.config) self.mesh = Mesh(devices_array, self.config.mesh_axes) quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) def test_setup_decode_state(self): rng = random.PRNGKey(0) - state, _ = maxtext_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) + if self.config.pure_nnx: + # NNX has a 
different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + state, _ = maxtext_utils.setup_decode_state(self.config, self.mesh, None, init_state_fn) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - state, _, _, _ = maxtext_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) + state, _, _, _ = maxtext_utils.setup_initial_state(None, self.config, self.mesh, None, init_state_fn) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) @@ -931,7 +945,8 @@ def setUp(self): def test_get_abstract_state(self): """Tests that get_abstract_state returns abstract arrays.""" # get_abstract_state returns a tuple, the first element is the abstract state. 
- abstract_state, _, _ = maxtext_utils.get_abstract_state(self.model, self.tx, self.config, self.rng, self.mesh, None) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) + abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self.mesh, init_state_fn) # Check that params are abstract param_leaves = jax.tree_util.tree_leaves(abstract_state.params) @@ -942,5 +957,435 @@ def test_get_abstract_state(self): self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) +class TestGetFunctionalTrainWithSignature(unittest.TestCase): + """Tests for get_functional_train_with_signature.""" + + def _make_mock_step(self): + def train_step(_model, _config, _state_shardings, _params_shardings, state, _batch, _rng=None): + return state, {} + + return train_step + + def test_returns_five_tuple(self): + step = self._make_mock_step() + result = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", "config" + ) + self.assertEqual(len(result), 5) + + def test_functional_train_has_correct_name(self): + step = self._make_mock_step() + fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", "config" + ) + self.assertEqual(fn.__name__, "train_step") + + def test_in_shardings_structure(self): + step = self._make_mock_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", "config" + ) + # (state, batch, rng) + self.assertEqual(len(in_shardings), 3) + self.assertIsNone(in_shardings[2]) # rng sharding is None + + def test_donate_argnums_is_zero(self): + step = self._make_mock_step() + _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", "config" + ) + self.assertEqual(donate_argnums, 0) + + def 
test_functional_train_is_partial(self): + """functional_train should partially apply model and config.""" + received = {} + + def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): + received["model"] = model + received["config"] = config + return state, {} + + fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") + fn("state", "batch") + self.assertEqual(received["model"], "my_model") + self.assertEqual(received["config"], "my_config") + + +class TestGetFunctionalEvalWithSignature(unittest.TestCase): + """Tests for get_functional_eval_with_signature.""" + + def _make_mock_eval_step(self): + def eval_step(_model, _config, _state, _batch, _rng=None): + return {} + + return eval_step + + def test_returns_five_tuple(self): + step = self._make_mock_eval_step() + result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + self.assertEqual(len(result), 5) + + def test_functional_eval_has_correct_name(self): + step = self._make_mock_eval_step() + fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + self.assertEqual(fn.__name__, "eval_step") + + def test_out_shardings_is_none(self): + step = self._make_mock_eval_step() + _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + self.assertIsNone(out_shardings) + + def test_donate_argnums_is_empty(self): + step = self._make_mock_eval_step() + _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + self.assertEqual(donate_argnums, ()) + + +class TestGetShapedBatch(unittest.TestCase): + """Tests for get_shaped_batch.""" + + def _make_cfg(self, *, enable_diloco=False, use_multimodal=False, use_audio=False): + cfg = MagicMock() + cfg.enable_diloco = enable_diloco + cfg.global_batch_size_to_load = 4 + cfg.max_target_length = 
16 + cfg.use_multimodal = use_multimodal + cfg.use_audio = use_audio + if enable_diloco: + cfg.num_diloco_replicas = 2 + return cfg + + def test_standard_keys_present(self): + batch = maxtext_utils.get_shaped_batch(self._make_cfg()) + for key in ( + "inputs", + "inputs_position", + "inputs_segmentation", + "targets", + "targets_position", + "targets_segmentation", + ): + self.assertIn(key, batch) + + def test_standard_shape(self): + cfg = self._make_cfg() + batch = maxtext_utils.get_shaped_batch(cfg) + expected_shape = (cfg.global_batch_size_to_load, cfg.max_target_length) + self.assertEqual(batch["inputs"].shape, expected_shape) + + def test_diloco_shape(self): + cfg = self._make_cfg(enable_diloco=True) + batch = maxtext_utils.get_shaped_batch(cfg) + expected_shape = ( + cfg.num_diloco_replicas, + cfg.global_batch_size_to_load // cfg.num_diloco_replicas, + cfg.max_target_length, + ) + self.assertEqual(batch["inputs"].shape, expected_shape) + + def test_no_image_key_without_multimodal(self): + batch = maxtext_utils.get_shaped_batch(self._make_cfg(use_multimodal=False)) + self.assertNotIn("images", batch) + + def test_no_audio_key_without_audio(self): + batch = maxtext_utils.get_shaped_batch(self._make_cfg(use_audio=False)) + self.assertNotIn("audios", batch) + + def test_all_values_are_shape_dtype_struct(self): + batch = maxtext_utils.get_shaped_batch(self._make_cfg()) + for v in batch.values(): + self.assertIsInstance(v, jax.ShapeDtypeStruct) + + +class TestShouldPreventCseInRemat(unittest.TestCase): + """Tests for should_prevent_cse_in_remat.""" + + def _make_cfg(self, scan_layers=False, gradient_accumulation_steps=1, hardware="tpu"): + cfg = MagicMock() + cfg.scan_layers = scan_layers + cfg.gradient_accumulation_steps = gradient_accumulation_steps + cfg.hardware = hardware + return cfg + + def test_scan_layers_returns_false(self): + self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg(scan_layers=True))) + + def 
test_gpu_with_grad_accum_returns_false(self): + cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu") + self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg)) + + def test_gpu_multiprocess_with_grad_accum_returns_false(self): + cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu_multiprocess") + self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg)) + + def test_tpu_with_grad_accum_returns_true(self): + cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="tpu") + self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(cfg)) + + def test_default_case_returns_true(self): + self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg())) + + +class TestCalculateTokensTrainingPerDevice(unittest.TestCase): + """Tests for calculate_tokens_training_per_device.""" + + def test_basic_calculation(self): + cfg = MagicMock() + cfg.max_target_length = 128 + cfg.per_device_batch_size = 2 + cfg.gradient_accumulation_steps = 4 + result = maxtext_utils.calculate_tokens_training_per_device(cfg) + self.assertEqual(result, 128 * 2 * 4) + + +class TestCalculateIndexerMaskRatio(unittest.TestCase): + """Tests for calculate_indexer_mask_ratio.""" + + def test_half_topk(self): + # K=T/2: ratio=0.5, mask = 0.5 - 0.5*0.25 = 0.375 + result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=4, max_target_length=8) + self.assertAlmostEqual(result, 0.375, places=6) + + def test_full_topk_equals_dense(self): + # K=T: ratio=1, mask = 1 - 0.5 = 0.5 (same as causal) + result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=8, max_target_length=8) + self.assertAlmostEqual(result, 0.5, places=6) + + def test_small_topk(self): + # K=1, T=100: ratio=0.01, mask ≈ 0.01 - 0.5*0.0001 ≈ 0.00995 + result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=1, max_target_length=100) + expected = 0.01 - 0.5 * (0.01**2) + self.assertAlmostEqual(result, expected, 
places=8) + + +class TestCalculateFfnMatmulTflops(unittest.TestCase): + """Tests for calculate_ffn_mamtul_tflops_per_device.""" + + def _make_cfg(self, num_activations=2): + cfg = MagicMock() + cfg.per_device_batch_size = 1 + cfg.max_target_length = 64 + cfg.emb_dim = 512 + cfg.mlp_activations = ["silu"] * num_activations + return cfg + + def test_total_flops_positive(self): + result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(self._make_cfg(), mlp_dim=2048) + self.assertGreater(result, 0) + + def test_scales_with_mlp_dim(self): + cfg = self._make_cfg() + small = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=1024) + large = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=4096) + self.assertGreater(large, small) + + def test_single_activation(self): + """With one activation, ffn1 uses 1x mlp_dim.""" + cfg = self._make_cfg(num_activations=1) + result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=2048) + expected_ffn1 = 2 * 1 * 64 * 2048 * 512 * 1 + expected_ffn2 = 2 * 1 * 64 * 2048 * 512 + self.assertEqual(result, expected_ffn1 + expected_ffn2) + + +class TestGetDenseMoeLayers(unittest.TestCase): + """Tests for get_dense_moe_layers.""" + + def _make_cfg(self, decoder_block, num_decoder_layers=32, first_num_dense_layers=3, interleave_moe_layer_step=4): + cfg = MagicMock() + cfg.decoder_block = decoder_block + cfg.num_decoder_layers = num_decoder_layers + cfg.first_num_dense_layers = first_num_dense_layers + cfg.interleave_moe_layer_step = interleave_moe_layer_step + return cfg + + def test_deepseek_block(self): + cfg = self._make_cfg(DecoderBlockType.DEEPSEEK, num_decoder_layers=32, first_num_dense_layers=3) + dense, moe = maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(dense, 3) + self.assertEqual(moe, 29) + + def test_llama4_block(self): + cfg = self._make_cfg(DecoderBlockType.LLAMA4, num_decoder_layers=16, interleave_moe_layer_step=4) + dense, moe = 
maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(moe, 4) # 16 // 4 + self.assertEqual(dense, 12) # 16 - 4 + + def test_qwen3_next_block(self): + cfg = self._make_cfg(DecoderBlockType.QWEN3_NEXT, num_decoder_layers=8) + dense, moe = maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(dense, 0) + self.assertEqual(moe, 8) + + def test_unsupported_block_raises(self): + cfg = self._make_cfg(DecoderBlockType.DEFAULT) + with self.assertRaises(ValueError): + maxtext_utils.get_dense_moe_layers(cfg) + + +class TestCalculatePrefillTflops(unittest.TestCase): + """Tests for calculate_prefill_tflops_per_device.""" + + def _make_cfg(self, num_query_heads=8, num_decoder_layers=2, head_dim=64): + cfg = MagicMock() + cfg.num_query_heads = num_query_heads + cfg.num_decoder_layers = num_decoder_layers + cfg.head_dim = head_dim + return cfg + + def test_returns_three_positive_values(self): + cfg = self._make_cfg() + total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False + ) + self.assertGreater(total, 0) + self.assertGreater(lw, 0) + self.assertGreater(attn, 0) + + def test_total_is_sum_of_parts(self): + cfg = self._make_cfg() + total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=500_000, prefill_length=64, config=cfg, log=False + ) + self.assertAlmostEqual(total, lw + attn, places=10) + + def test_scales_with_prefill_length(self): + cfg = self._make_cfg() + _, _, attn_short = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=64, config=cfg, log=False + ) + _, _, attn_long = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False + ) + # Attention scales quadratically with prefill length + self.assertGreater(attn_long, attn_short * 3) + + +class TestSetupTrainingState(unittest.TestCase): + """Tests for setup_training_state 
(thin wrapper over setup_initial_state).""" + + def setUp(self): + extra_args = get_decoupled_parallelism_overrides() + self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + quant = quantizations.configure_quantization(self.config) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX path not covered by this test.") + self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + + def test_setup_training_state_returns_train_state(self): + rng = jax.random.PRNGKey(0) + tx = optax.adam(learning_rate=0.001) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) + state, _, _, _ = maxtext_utils.setup_training_state(None, self.config, self.mesh, None, init_state_fn) + self.assertEqual(state.tx, tx) + self.assertNotEqual(state.opt_state, {}) + + +class TestGetLogicalAnnotations(unittest.TestCase): + """Tests for get_logical_annotations.""" + + def setUp(self): + extra_args = get_decoupled_parallelism_overrides() + self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + quant = quantizations.configure_quantization(self.config) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX path not covered by this test.") + self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + self.rng = jax.random.PRNGKey(0) + self.tx = optax.adam(learning_rate=0.001) + + def test_returns_partition_spec_tree(self): + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) + annotations = 
maxtext_utils.get_logical_annotations(self.config, self.mesh, init_state_fn) + # Result should be a pytree with PartitionSpec leaves + leaves = jax.tree_util.tree_leaves(annotations) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, PartitionSpec) + + +class TestSaveQuantizedCheckpoint(unittest.TestCase): + """Tests for save_quantized_checkpoint_if_configured.""" + + def test_raises_when_no_quantization(self): + cfg = MagicMock() + cfg.quantization = "" + with self.assertRaises(AssertionError): + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) + + @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") + def test_skips_save_when_path_empty(self, mock_ckpt): + cfg = MagicMock() + cfg.quantization = "int8" + cfg.save_quantized_params_path = "" + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) + mock_ckpt.save_params_to_path.assert_not_called() + + @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") + def test_calls_save_when_path_set(self, mock_ckpt): + cfg = MagicMock() + cfg.quantization = "int8" + cfg.save_quantized_params_path = "/tmp/quantized" + cfg.checkpoint_storage_use_ocdbt = True + cfg.checkpoint_storage_use_zarr3 = True + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={"w": jnp.ones((2,))}) + mock_ckpt.save_params_to_path.assert_called_once() + + +class TestAddConfigToSummaryWriter(unittest.TestCase): + """Tests for add_config_to_summary_writer.""" + + def test_calls_add_text_for_each_key(self): + cfg = MagicMock() + cfg.get_keys.return_value = {"learning_rate": 0.001, "steps": 100} + mock_writer = MagicMock() + + with unittest.mock.patch("maxtext.utils.max_utils.add_text_to_summary_writer") as mock_add: + maxtext_utils.add_config_to_summary_writer(cfg, mock_writer) + # Should have been called once per config key (process_index==0 in tests) + if jax.process_index() == 0: + self.assertEqual(mock_add.call_count, 2) + + +class 
TestMaybeDumpJaxpr(unittest.TestCase): + """Tests for maybe_dump_jaxpr.""" + + def test_early_return_when_disabled(self): + cfg = MagicMock() + cfg.dump_jaxpr = False + # Should return immediately without calling any JAX tracing (no exception raised) + maxtext_utils.maybe_dump_jaxpr(cfg, p_train_step=None, train_step_inputs=None) + + +class TestPrintShardingsParams(unittest.TestCase): + """Tests for print_shardings_params — normalization branches.""" + + def setUp(self): + """Build a minimal mesh and sharded param for testing.""" + self.mesh = Mesh(np.array(jax.devices()), ("data",)) + + def _make_simple_params(self): + """Return (params, param_sharding, logical) without a .params attribute.""" + params = {"w": jnp.ones((4,))} + param_sharding = {"w": NamedSharding(self.mesh, PartitionSpec(None))} + logical = {"w": PartitionSpec(None)} + return params, param_sharding, logical + + def test_runs_without_error_dict_inputs(self): + """print_shardings_params should not raise with plain dict inputs.""" + params, param_sharding, logical = self._make_simple_params() + # Should complete without raising + maxtext_utils.print_shardings_params(params, param_sharding, logical) + + def test_runs_without_logical_annotations(self): + """logical_annotations=None should be handled (no logical column).""" + params, param_sharding, _ = self._make_simple_params() + maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py new file mode 100644 index 0000000000..d4c47c4fb6 --- /dev/null +++ b/tests/unit/model_creation_utils_test.py @@ -0,0 +1,326 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for model_creation_utils.py.""" + +import sys +import unittest +from unittest.mock import MagicMock, patch + +import jax +import flax.linen as nn +from flax import nnx +from jax.sharding import Mesh + +from maxtext.configs import pyconfig +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL +from maxtext.models import models +from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils +from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides + + +def _make_config(**kwargs): + """Returns a minimal pyconfig suitable for model-creation tests.""" + extra = get_decoupled_parallelism_overrides() + defaults = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "base_num_decoder_layers": 2, + "attention": "dot_product", + "max_target_length": 16, + "base_emb_dim": 256, + "base_num_query_heads": 2, + "base_num_kv_heads": 2, + "max_prefill_predict_length": 4, + } + defaults.update(kwargs) + return pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **defaults, + **extra, + ) + + +def _make_mesh(config): + devices_array = maxtext_utils.create_device_mesh(config) + return Mesh(devices_array, config.mesh_axes) + + +class TestGetTransformerModel(unittest.TestCase): + """Tests for get_transformer_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_linen_module_when_rngs_is_none(self): + """Without rngs, should return a Linen nn.Module.""" + model 
= model_creation_utils.get_transformer_model(self.config, self.mesh, quant=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_returns_nnx_module_when_rngs_provided(self): + """With rngs, should return an NNX nnx.Module.""" + model = nnx.eval_shape( + lambda: model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, rngs=nnx.Rngs(params=0, dropout=1, aqt=2) + ) + ) + self.assertIsInstance(model, nnx.Module) + + def test_respects_model_mode_prefill(self): + """Linen model created with MODEL_MODE_PREFILL should differ from train mode.""" + linen_train = model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, model_mode=MODEL_MODE_TRAIN, rngs=None + ) + linen_prefill = model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, model_mode=MODEL_MODE_PREFILL, rngs=None + ) + # Both are still nn.Module instances + self.assertIsInstance(linen_train, nn.Module) + self.assertIsInstance(linen_prefill, nn.Module) + + +class TestCreateModel(unittest.TestCase): + """Tests for create_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_linen_model_without_rngs(self): + model = model_creation_utils.create_model(self.config, self.mesh) + self.assertIsInstance(model, nn.Module) + + def test_returns_nnx_model_with_rngs(self): + model = nnx.eval_shape( + lambda: model_creation_utils.create_model(self.config, self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) + ) + self.assertIsInstance(model, nnx.Module) + + def test_model_mode_train_default(self): + """Default model_mode is MODEL_MODE_TRAIN.""" + model = model_creation_utils.create_model(self.config, self.mesh) + self.assertIsInstance(model, nn.Module) + + +class TestFromConfig(unittest.TestCase): + """Tests for from_config().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_linen_path_rngs_none(self): + 
"""from_config with rngs=None should return a Linen nn.Module.""" + model = model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_nnx_path_with_rngs(self): + """from_config with rngs provided should return an NNX nnx.Module.""" + model = nnx.eval_shape( + lambda: model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) + ) + self.assertIsInstance(model, nnx.Module) + + def test_mesh_created_from_devices_when_none(self): + """from_config should work when mesh is None (creates mesh internally).""" + model = model_creation_utils.from_config(self.config, devices=None, mesh=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_model_mode_is_forwarded(self): + """from_config should accept and forward model_mode.""" + model = model_creation_utils.from_config(self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_explicit_shard_mode_creates_mesh_with_explicit_axis_types(self): + """from_config with shard_mode=explicit should create mesh using AxisType.Explicit.""" + cfg = _make_config(shard_mode="explicit") + # Should not raise; mesh is built with AxisType.Explicit for each axis + model = model_creation_utils.from_config(cfg, mesh=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + +class TestCreateNNXAbstractModel(unittest.TestCase): + """Tests for create_nnx_abstract_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_tuple_of_callable_and_module(self): + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_abstract_model_has_abstract_arrays(self): + """Abstract model leaves should be ShapeDtypeStruct, not concrete arrays.""" + _, 
abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + _, state = nnx.split(abstract_model) + leaves = jax.tree.leaves(state) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + # In abstract state, values are nnx.Variable wrapping abstract shapes/ShapeDtypeStruct + # Concrete jax.Array would have a .devices() method; abstract ones should not be Arrays + self.assertNotIsInstance(leaf, jax.Array) + + def test_create_fn_produces_concrete_model(self): + """The returned create_fn should produce a real (concrete) NNX Module.""" + create_fn, _ = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + with self.mesh: + concrete = create_fn() + self.assertIsInstance(concrete, nnx.Module) + leaves = jax.tree.leaves(nnx.state(concrete)) + for leaf in leaves: + self.assertIsInstance(leaf, jax.Array) + + def test_works_without_explicit_mesh(self): + """create_nnx_abstract_model should work when mesh=None (from_config creates mesh).""" + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=None) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_explicit_rng_key_is_used(self): + """Passing a rng_key should not raise and returns valid abstract model.""" + rng_key = jax.random.PRNGKey(42) + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model( + self.config, mesh=self.mesh, rng_key=rng_key + ) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_prefill_model_mode(self): + """create_nnx_abstract_model should accept MODEL_MODE_PREFILL.""" + _, abstract_model = model_creation_utils.create_nnx_abstract_model( + self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL + ) + self.assertIsInstance(abstract_model, nnx.Module) + + +class TestCreateNnxModel(unittest.TestCase): + """Tests for create_nnx_model().""" + + def setUp(self): + self.config = 
_make_config() + self.mesh = _make_mesh(self.config) + + def test_no_checkpoint_returns_model_and_mesh(self): + """Without load_parameters_path, should return (model, mesh) cleanly.""" + model, mesh = model_creation_utils.create_nnx_model(self.config, self.mesh) + self.assertIsInstance(model, models.Transformer) + self.assertIsInstance(mesh, Mesh) + + def test_mesh_none_uses_abstract_model_mesh(self): + """When mesh=None is passed, the function resolves it from the abstract model.""" + model, mesh = model_creation_utils.create_nnx_model(self.config, mesh=None) + self.assertIsInstance(model, models.Transformer) + self.assertIsInstance(mesh, Mesh) + + def test_explicit_rng_key(self): + """An explicit rng_key should be accepted without error.""" + rng_key = jax.random.PRNGKey(99) + model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, rng_key=rng_key) + self.assertIsInstance(model, models.Transformer) + + def test_inference_mode_disables_dropout_rng(self): + """MODEL_MODE_PREFILL should create rngs without a dropout key.""" + model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, model_mode=MODEL_MODE_PREFILL) + self.assertIsInstance(model, models.Transformer) + + def test_debug_sharding_flag(self): + """debug_sharding=True should execute the sharding-print path without error.""" + cfg = _make_config(debug_sharding=True) + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + # ---- checkpoint loading: mocked paths ---- + + def _make_linen_metadata_mock(self): + """Mock ocp metadata that looks like a Linen checkpoint.""" + meta = MagicMock() + meta.item_metadata.tree.keys.return_value = ["params"] + meta.item_metadata.tree.get.return_value = {"params": {}} + return meta + + def _make_nnx_metadata_mock(self): + """Mock ocp metadata that looks like an NNX checkpoint.""" + meta = MagicMock() + meta.item_metadata.tree.keys.return_value = ["decoder"] + 
meta.item_metadata.tree.get.return_value = {} + return meta + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_load_nnx_checkpoint(self, mock_ocp): + """NNX-format checkpoint: restored values are wrapped under a 'value' key.""" + _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) + _, abstract_state = nnx.split(abstract_model) + + # Build a fake restored dict with 'value' keys (NNX checkpoint structure) + fake_restored = jax.tree.map( + lambda v: {"value": v.value}, + abstract_state, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock() + mock_ckptr.restore.return_value = fake_restored + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/nnx_ckpt") + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_load_linen_checkpoint(self, mock_ocp): + """Linen-format checkpoint: restored values are nested under 'params'/'params'.""" + _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) + _, abstract_state = nnx.split(abstract_model) + + # Build fake plain-value dict (Linen structure) + fake_params = jax.tree.map( + lambda v: v.value, + abstract_state, + is_leaf=lambda n: isinstance(n, nnx.Variable), + ) + fake_restored = {"params": {"params": fake_params}} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() + mock_ckptr.restore.return_value = fake_restored + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + 
mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/linen_ckpt") + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_checkpoint_load_error_raises_value_error(self, mock_ocp): + """Any exception during checkpoint loading should be re-raised as ValueError.""" + mock_ckptr = MagicMock() + mock_ckptr.metadata.side_effect = RuntimeError("disk on fire") + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/bad_ckpt") + with self.assertRaises(ValueError): + model_creation_utils.create_nnx_model(cfg, self.mesh) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index 2cd696f241..c9e4deb725 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -14,6 +14,7 @@ """Compare expected sharding of models with actual sharding of models.""" +import functools import hashlib import json import os @@ -127,6 +128,9 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) f"model_name={model_name}", "log_config=false", "debug_sharding=true", # for input sharding dump + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] root_dir = "tests/utils/sharding_info" @@ -215,6 +219,9 @@ def abstract_state_and_shardings(request): f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", "weight_dtype=float32", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] config = pyconfig.initialize(params) validate_config(config) @@ -228,13 +235,15 @@ def abstract_state_and_shardings(request): tx = optimizers.get_optimizer(config, 
learning_rate_schedule) rng = jax.random.PRNGKey(0) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # Get abstract state and physical shardings from maxtext_utils abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, rng, topology_mesh, is_training=True + config, topology_mesh, init_state_fn, is_training=True ) # Get logical shardings from maxtext_utils - logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True) + logical_shardings = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings diff --git a/tests/unit/state_dtypes_test.py b/tests/unit/state_dtypes_test.py index 77e166193a..10db1bf199 100644 --- a/tests/unit/state_dtypes_test.py +++ b/tests/unit/state_dtypes_test.py @@ -13,6 +13,7 @@ # limitations under the License. """ Test that all weights are expected dtype (default float32) """ +from functools import partial import unittest import jax @@ -47,7 +48,12 @@ def get_state(self, argv): tx = optimizers.get_optimizer(config, learning_rate_schedule) _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, example_rng, mesh) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) return abstract_state def get_weights(self, argv): diff --git a/tests/unit/train_utils_test.py b/tests/unit/train_utils_test.py new file mode 100644 index 0000000000..a8b9458794 --- /dev/null +++ b/tests/unit/train_utils_test.py @@ -0,0 +1,196 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for train_utils.py.""" + +import unittest +from dataclasses import dataclass +from unittest.mock import MagicMock + +from maxtext.utils.train_utils import validate_train_config, create_training_optimizer + + +@dataclass +class MockConfig: + """Minimal mock config for validate_train_config tests.""" + + run_name: str = "test_run" + dataset_path: str = "gs://test-bucket/data" + base_output_directory: str = "gs://test-bucket/output" + steps: int = 100 + quantization: str = "" + gradient_accumulation_steps: int = 1 + packing: bool = False + dataset_type: str = "tfds" + + # Fields needed for create_training_optimizer + opt_type: str = "adamw" + adam_b1: float = 0.9 + adam_b2: float = 0.95 + adam_eps: float = 1e-8 + adam_eps_root: float = 0.0 + adam_weight_decay: float = 0.1 + mu_dtype: str = "" + learning_rate: float = 1e-4 + learning_rate_schedule_steps: int = 1000 + warmup_steps_fraction: float = 0.1 + cosine_learning_rate_final_fraction: float = 0.0 + steps: int = 100 + lr_schedule_type: str = "cosine" + use_iota_embed: bool = False + + +class TestValidateTrainConfig(unittest.TestCase): + """Tests for validate_train_config.""" + + def test_valid_config_passes(self): + """Verifies no exception raised for a valid config.""" + config = MockConfig() + # Should not raise + validate_train_config(config) + + def test_missing_run_name_raises(self): + """Verifies AssertionError when run_name is empty.""" + config = MockConfig(run_name="") + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_zero_steps_raises(self): + """Verifies AssertionError when steps is 0.""" + config = MockConfig(steps=0) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_negative_steps_raises(self): + """Verifies AssertionError when steps is negative.""" + config = MockConfig(steps=-5) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_grad_accumulation_raises(self): + 
"""Verifies AssertionError for fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=2) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_nanoo_fp8_with_grad_accumulation_raises(self): + """Verifies AssertionError for nanoo_fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="nanoo_fp8", gradient_accumulation_steps=4) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_single_grad_accumulation_passes(self): + """Verifies no error for fp8 with gradient_accumulation_steps=1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=1) + validate_train_config(config) # Should not raise + + def test_packing_with_synthetic_data_logs_warning(self): + """Verifies no exception for packing + synthetic (just logs a warning).""" + config = MockConfig(packing=True, dataset_type="synthetic") + # Should not raise - just log a warning + validate_train_config(config) + + def test_local_dataset_path_logs_warning(self): + """Verifies no exception for local dataset_path (just logs a warning).""" + config = MockConfig(dataset_path="/local/path/to/data") + validate_train_config(config) # Should not raise + + def test_local_output_directory_logs_warning(self): + """Verifies no exception for local base_output_directory (just logs a warning).""" + config = MockConfig(base_output_directory="/local/output") + validate_train_config(config) # Should not raise + + +class TestCreateTrainingOptimizer(unittest.TestCase): + """Tests for create_training_optimizer.""" + + def _make_config(self, opt_type="adamw", **kwargs): + """Creates a mock config for optimizer tests.""" + cfg = MockConfig(opt_type=opt_type, **kwargs) + return cfg + + def _mock_lr_schedule(self): + """Returns a mock learning rate schedule that returns a fixed value.""" + return lambda step: 1e-4 + + def 
test_adamw_optimizer_returns_schedule_and_tx(self): + """Verifies create_training_optimizer returns a schedule and optax transform for adamw.""" + config = MagicMock() + config.opt_type = "adamw" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + schedule, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(schedule) + self.assertIsNotNone(tx) + # Verify it's an optax GradientTransformation + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_adam_pax_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for adam_pax optimizer.""" + config = MagicMock() + config.opt_type = "adam_pax" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_sgd_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for sgd optimizer.""" + config = MagicMock() + config.opt_type = "sgd" + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.0 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = 
"cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index c6a8e37997..4dc3e094c9 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -37,6 +37,7 @@ """Check if the logits generated by a model's src/maxtext/HF implementation matches golden logits for the same inputs""" import argparse +import functools import os from pathlib import Path import sys @@ -242,8 +243,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng1) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) if test_args.golden_logits_path == "": input_golden_data_path = os.path.join( @@ -426,8 +432,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1) + maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) prompts = ["I love to", "Today is a", "What is the"] all_data_to_save = [] diff --git a/tools/gcs_benchmarks/standalone_checkpointer.py b/tools/gcs_benchmarks/standalone_checkpointer.py index 6240c10cc0..9f39cc529f 100644 --- a/tools/gcs_benchmarks/standalone_checkpointer.py +++ b/tools/gcs_benchmarks/standalone_checkpointer.py @@ -19,6 +19,7 @@ # See github.com/google/maxtext/issues/20 for more import datetime +from functools import partial import os from typing import Sequence @@ -51,15 +52,21 @@ def checkpoint_loop(config, state=None): Returns: """ - model = from_config(config) + if config.pure_nnx: + raise 
NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = from_config(config) mesh = model.mesh - init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools( - config, model, mesh - ) - - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, tx, config, init_rng, mesh, is_training=True - ) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + _, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) + + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() @@ -82,30 +89,24 @@ def checkpoint_loop(config, state=None): if state is not None: # Checkpoint was available for restore if jax.process_index() == 0: max_logging.log( - "STANDALONE CHECKPOINTER : Checkpoint restored in :" - f" {checkpoint_load_end - checkpoint_load_start}" + "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}" ) else: # Checkpoint was unavailable, state needs to be initialized - state, _, _, _ = maxtext_utils.setup_training_state( - model, None, tx, config, init_rng, mesh, checkpoint_manager - ) + state, _, _, _ = maxtext_utils.setup_training_state(None, config, mesh, checkpoint_manager, init_state_fn) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step 
for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() # A barrier to sync all hosts before starting to save checkpoint - jax.experimental.multihost_utils.sync_global_devices( - "Barrier before save" - ) + jax.experimental.multihost_utils.sync_global_devices("Barrier before save") if checkpointing.save_checkpoint(checkpoint_manager, int(step), state): checkpoint_manager.wait_until_finished() end_time = datetime.datetime.now() if jax.process_index() == 0: max_logging.log( - "STANDALONE CHECKPOINTER : Checkpoint saved in" - f" {end_time - start_time} ,step {step}, on host 0" + "STANDALONE CHECKPOINTER : Checkpoint saved in" f" {end_time - start_time} ,step {step}, on host 0" ) return state @@ -123,12 +124,8 @@ def add_entropy_to_checkpoint(state): state: Returns state with entropy added to the optimizer state. """ opt_0 = state.opt_state[0] - opt_0 = opt_0._replace( - mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params) - ) - opt_0 = opt_0._replace( - nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params) - ) + opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)) + opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)) new_opt = [opt_0] + list(state.opt_state[1:]) state = state.replace(opt_state=new_opt) return state diff --git a/tools/gcs_benchmarks/standalone_dataloader.py b/tools/gcs_benchmarks/standalone_dataloader.py index 9766349aac..54177e9528 100644 --- a/tools/gcs_benchmarks/standalone_dataloader.py +++ b/tools/gcs_benchmarks/standalone_dataloader.py @@ -38,13 +38,13 @@ def data_load_loop(config, state=None): """Main data loader loop. Loads batches of data for each training step. 
""" - _, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) + _, _, _, model, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) data_loader = DataLoader(config, mesh, data_iterator, None) example_batch = None start = datetime.datetime.now() - start_step = get_first_step(state) + start_step = get_first_step(model, state) example_batch = data_loader.load_next_batch() jax.block_until_ready(example_batch) first_end = datetime.datetime.now() From 56c9cb73e81049350b471a36f5d68279afb97b32 Mon Sep 17 00:00:00 2001 From: Xibin Liu Date: Wed, 21 Jan 2026 00:46:10 +0000 Subject: [PATCH 2/3] NNX migration: NNX utils - Add utils to manipulate the NNX shardings with abstract state of a model - also add unit tests for the utils - Extract mesh creation function to maxtext_utils.get_mesh_from_config() - also add unit tests for this func Note: flax v0.12 has DeprecationWarning in multiple places: - DeprecationWarning: '.value' access is now deprecated. Use variable.get_value() or variable[...] (for [Array]). - DeprecationWarning: 'VariableState' was removed, this is just an alias to 'Variable'. Plase use 'Variable' directly instead. But since the code needs to work with post-training, which currently requires flax v0.11, we didn't change code for these warnings. 
--- src/maxtext/utils/maxtext_utils.py | 29 +++- src/maxtext/utils/model_creation_utils.py | 40 +++-- tests/unit/maxtext_utils_nnx_test.py | 182 ++++++++++++++++++++++ tests/unit/maxtext_utils_test.py | 95 +++++++---- 4 files changed, 289 insertions(+), 57 deletions(-) create mode 100644 tests/unit/maxtext_utils_nnx_test.py diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 35ba11c389..4a1335d732 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -18,6 +18,7 @@ import functools import pickle import os +from typing import Sequence from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -27,6 +28,7 @@ from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load +from jax.sharding import AxisType, Mesh import jax import jax.numpy as jnp @@ -36,7 +38,8 @@ import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.configs import pyconfig +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing @@ -1521,3 +1524,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging all_host_upload=False, # Only upload from lead host (Host 0) ) + + +def get_mesh_from_config( + config: pyconfig.HyperParameters, + devices: Sequence[jax.Device] | None = None, +) -> Mesh: + """ + Get mesh from the configuration. 
+ + Args: + config: the configuration + devices: the devices + + Returns: + the device mesh + """ + devices_array = create_device_mesh(config, devices) + + if config.shard_mode == ShardMode.EXPLICIT: + axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) + else: + axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) + + return Mesh(devices_array, config.mesh_axes, axis_types=axis_types) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index a010aac993..805d64bb21 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -18,18 +18,16 @@ from collections.abc import Sequence from functools import partial from typing import overload - from etils import epath from flax import nnx import flax.linen as nn import jax -from jax.sharding import AxisType, Mesh +from jax.sharding import Mesh from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import max_utils -from maxtext.utils import maxtext_utils +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx from orbax import checkpoint as ocp @@ -40,6 +38,7 @@ def from_config( mesh: Mesh | None = None, *, model_mode: str = MODEL_MODE_TRAIN, + rngs: None = None, ) -> nn.Module: ... 
@@ -80,15 +79,7 @@ def from_config( model = from_config(config) """ if mesh is None: - devices_array = maxtext_utils.create_device_mesh(config, devices) - - if config.shard_mode == ShardMode.EXPLICIT: - axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) - else: - axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) - - mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types) - + mesh = maxtext_utils.get_mesh_from_config(config, devices) model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) # Return only the model @@ -146,16 +137,10 @@ def _create_model(rng_key=None): def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + is_training = model_mode == MODEL_MODE_TRAIN def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - - if model_mode == MODEL_MODE_TRAIN: - rngs = nnx.Rngs(params=rng_key, dropout=1) - else: - rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference - + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) @@ -168,6 +153,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, if mesh is None: mesh = abstract_model.mesh + # Note for pure_nnx: + # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and + # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen + # LogicallyPartitioned structure. 
+ # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned + # structure in the abstract state and we can get the sharded state with the following code: + # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) + # abstract_model = nnx.merge(graphdef, state) + # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) + # sharded_state = nnx.state(model) + # JIT a function that creates the model state with proper sharding from the start. # By providing out_shardings, we instruct JAX to produce sharded output directly, # avoiding a large intermediate allocation on a single device. diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py new file mode 100644 index 0000000000..0eb1f7ef77 --- /dev/null +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -0,0 +1,182 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" Tests for the common MaxText NNX utilities """ +import unittest +from dataclasses import dataclass +from typing import Any +import jax +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental import mesh_utils + +from maxtext.utils import maxtext_utils_nnx + + +class TestMaxTextUtilsNNX(unittest.TestCase): + """Test the functions for MaxText Utils.""" + + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" + + init_weights_seed: int = 42 + + class TinyModel(nnx.Module): + """ + A tiny NNX model with logical annotations. + Annotations are required to test that sharding extraction logic works. + """ + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + jax.device_count(), + jax.device_count(), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)), + # FIX: Removed () from zeros. zeros is the initializer function itself, + # not a factory like lecun_normal(). + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)), + rngs=rngs, + ) + + def tiny_model_init_fn(self): + """Factory function for model initialization.""" + return self.TinyModel(rngs=nnx.Rngs(0)) + + def setUp(self): + # Create a mesh for sharding tests. + # NamedSharding requires an active Mesh to resolve logical names. + self.devices = mesh_utils.create_device_mesh((jax.device_count(),)) + self.mesh = Mesh(self.devices, axis_names=("data",)) + + def test_create_nnx_rngs_training(self): + # Using Any to satisfy static type checkers for the MockConfig + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=True) + + self.assertIsInstance(rngs, nnx.Rngs) + # FIX: nnx.Rngs does not have a .streams attribute. + # Check for stream attributes directly on the object. 
+ self.assertTrue(hasattr(rngs, "params")) + self.assertTrue(hasattr(rngs, "dropout")) + self.assertTrue(hasattr(rngs, "aqt")) + + def test_create_nnx_rngs_inference(self): + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=False) + + self.assertIsInstance(rngs, nnx.Rngs) + # Check that 'params' exists but 'dropout' and 'aqt' were excluded + self.assertTrue(hasattr(rngs, "params")) + self.assertFalse(hasattr(rngs, "dropout")) + self.assertFalse(hasattr(rngs, "aqt")) + + def test_move_memory(self): + sharding = NamedSharding(self.mesh, P("data")) + self.assertNotEqual(sharding.memory_kind, "pinned_host") + + path = ("layers", "linear", "kernel") + host_sharding = maxtext_utils_nnx.move_memory_to_host(path, sharding) + + self.assertEqual(host_sharding.memory_kind, "pinned_host") + self.assertEqual(host_sharding.spec, P("data")) + + device_sharding = maxtext_utils_nnx.move_memory_to_device(path, sharding) + + self.assertEqual(device_sharding.memory_kind, "device") + self.assertEqual(device_sharding.spec, P("data")) + + def test_get_set_named_sharding_nnx(self): + # 1. Create the abstract state using standard NNX functional API + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + + # 2. 
Test extraction + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # Verify kernel and bias match the P("data") annotations from TinyModel + self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) + self.assertEqual(extracted_shardings.linear.bias.get_value().spec, P("data")) + + # Target kernel spec update + new_kernel_spec = P(None, "data") + + def update_spec_fn(path, leaf_sharding): + path_str = jax.tree_util.keystr(path) + if "linear" in path_str and "kernel" in path_str: + # Construct a new NamedSharding with the requested logical spec + return NamedSharding(leaf_sharding.mesh, new_kernel_spec) + return leaf_sharding + + # Apply the spec change to the extracted sharding tree + extracted_shardings = jax.tree.map_with_path(update_spec_fn, extracted_shardings) + + # 3. Test setting new shardings + # Transform the extracted shardings to host memory + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, new_shardings) + + # Verify the metadata inside the abstract state leaf has updated its sharding + self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host") + # Also verify the spec was updated successfully + self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec) + + # 4. Verify named sharding is preserved after NNX merge (update) and split (state) + model = self.tiny_model_init_fn() + nnx.update(model, updated_abstract) + re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) + + # Verify kernel and bias have expected sharding + self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) + self.assertEqual(re_extracted_shardings.linear.bias.get_value().spec, P("data")) + + def test_create_nnx_sharded_model(self): + # 1. 
Create abstract model + graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + abstract_model = nnx.merge(graphdef, abstract_state) + + # 2. Modify shardings to trigger host offloading + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + + # 3. Run the sharded creation + # We pass the abstract model and use the custom sharding for instantiation + sharded_model = maxtext_utils_nnx.create_nnx_sharded_model( + abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=new_shardings + ) + + # 4. Verify the model is concrete (contains Arrays) and sharded on host + self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array) + self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host") + + def test_get_partition_spec_nnx(self): + """Verifies extraction of PartitionSpecs from NamedShardings.""" + # 1. Create abstract state and get sharding + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # 2. Execute extraction + spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) + + # 3. 
Verify that the leaves are now raw PartitionSpecs + # Expected values derived from TinyModel definition + expected_spec_k = P("data", None) + expected_spec_b = P("data") + + self.assertEqual(spec["linear"]["kernel"], expected_spec_k) + self.assertEqual(spec["linear"]["bias"], expected_spec_b) + self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 4850e972b3..7a09750a86 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,10 +15,11 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any +from typing import Any, Sequence from collections.abc import Callable import unittest -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch +from dataclasses import dataclass, field from flax import linen as nn from flax import nnx @@ -27,9 +28,9 @@ import jax from jax import random, vmap import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec from maxtext.configs import pyconfig -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils from maxtext.layers import quantizations from maxtext.models import models @@ -922,39 +923,65 @@ def test_wsd_schedule(self): self.assertIn("wsd_decay_steps_fraction", str(cm.exception)) -class TestGetAbstractState(unittest.TestCase): - """Test class for get_abstract_state.""" +class TestMeshUtils(unittest.TestCase): + """Test suite for the mesh creation utility function.""" - def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize( - [None, get_test_config_path()], - **extra_args, - 
enable_checkpointing=False, - model_name="llama3.1-8b", - per_device_batch_size=1, - max_target_length=16, - ) - devices_array = maxtext_utils.create_device_mesh(self.config) - self.mesh = Mesh(devices_array, self.config.mesh_axes) - quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - self.rng = jax.random.PRNGKey(0) - self.tx = optax.adam(learning_rate=0.001) + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" - def test_get_abstract_state(self): - """Tests that get_abstract_state returns abstract arrays.""" - # get_abstract_state returns a tuple, the first element is the abstract state. - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) - abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self.mesh, init_state_fn) - - # Check that params are abstract - param_leaves = jax.tree_util.tree_leaves(abstract_state.params) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in param_leaves)) + init_weights_seed: int = 42 + shard_mode: str = ShardMode.EXPLICIT + mesh_axes: Sequence[str] = field(default_factory=lambda: ["data", "model"]) - # Check that opt_state is abstract - opt_state_leaves = jax.tree_util.tree_leaves(abstract_state.opt_state) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) + def setUp(self): + # Setup a dummy device array for the mock to return + self.devices_array = np.array(jax.devices()) + + @patch("MaxText.maxtext_utils.create_device_mesh") + def test_get_mesh_explicit_mode(self, mock_create_device_mesh): + """Tests that ShardMode.EXPLICIT sets axis_types to MANUAL.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:1].reshape((1,)) + config = self.MockConfig(shard_mode=ShardMode.EXPLICIT, mesh_axes=["data"]) + + # 2. 
Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. Assertions + # Check that the internal utility was called correctly + mock_create_device_mesh.assert_called_once_with(config, None) + + # Verify Mesh properties + self.assertEqual(mesh.axis_names, ("data",)) + # ShardMode.EXPLICIT should map every mesh axis to AxisType.Explicit + self.assertEqual(mesh.axis_types, (AxisType.Explicit,)) + + @patch("MaxText.maxtext_utils.create_device_mesh") + def test_get_mesh_auto_mode(self, mock_create_device_mesh): + """Tests that ShardMode.AUTO sets axis_types to AUTO.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:2].reshape((2, 1)) + config = self.MockConfig(shard_mode=ShardMode.AUTO, mesh_axes=["data", "model"]) + + # 2. Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. Assertions + self.assertEqual(len(mesh.axis_types), 2) + self.assertTrue(all(t == AxisType.Auto for t in mesh.axis_types)) + + @patch("MaxText.maxtext_utils.create_device_mesh") + def test_get_mesh_with_provided_devices(self, mock_create_device_mesh): + """Tests that provided devices are passed through to the mesh creator.""" + config = self.MockConfig() + specific_devices = self.devices_array[:2].reshape((1, 2)) + mock_create_device_mesh.return_value = specific_devices + + _ = maxtext_utils.get_mesh_from_config(config, devices=specific_devices) + + # Verify the second argument to create_device_mesh was our device list + mock_create_device_mesh.assert_called_once_with(config, specific_devices) class TestGetFunctionalTrainWithSignature(unittest.TestCase): From 0e0a8cd2611825290eb2a9c9ff2f7123d1190c28 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 25 Mar 2026 20:02:57 +0000 Subject: [PATCH 3/3] NNX: add TrainState, model creation utilities, and training loop support - Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to 
muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch --- src/maxtext/common/checkpointing.py | 32 +- src/maxtext/trainers/pre_train/train.py | 423 +++++++++++------- src/maxtext/utils/gradient_accumulation.py | 35 +- src/maxtext/utils/maxtext_utils.py | 211 ++++++++- src/maxtext/utils/model_creation_utils.py | 130 ++++-- src/maxtext/utils/muon_utils.py | 60 ++- src/maxtext/utils/sharding.py | 121 ++++- src/maxtext/utils/train_utils.py | 54 ++- tests/unit/maxtext_utils_test.py | 185 +++++++- tests/unit/optimizers_test.py | 117 ++++- tests/unit/train_state_nnx_checkpoint_test.py | 291 ++++++++++++ tests/unit/train_state_nnx_test.py | 90 ++++ 12 files changed, 1448 insertions(+), 301 deletions(-) create mode 100644 tests/unit/train_state_nnx_checkpoint_test.py create mode 100644 tests/unit/train_state_nnx_test.py diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index cdfde92d50..f9b5af575c 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -532,7 +533,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, # -1 means latest @@ -600,8 +601,13 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) - 
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) + checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args) match (checkpoint_manager, dataset_type, data_iterator): # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager @@ -636,9 +642,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -730,7 +741,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. - actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. 
# This occurs if this function was called: diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index c2f32f076c..4b79c83ecb 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -34,8 +34,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -66,6 +67,7 @@ from maxtext.utils import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss @@ -90,11 +92,11 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). 
is_train: True for train_step and False for eval_step Returns: @@ -173,7 +175,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): total_loss = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -184,7 +186,11 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - intermediate_outputs = {} + # Capture NNX intermediates (MoE losses, hidden states, etc.) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + + if config.num_vocab_tiling > 1: + raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. @@ -296,62 +302,98 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. 
+ new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. - """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - params = state.params + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + if config.use_dpo: + raise NotImplementedError("DPO for NNX modules has not been implemented.") + state = nnx.merge(model, state) # reconstruct TrainStateNNX + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - 
functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, - ) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + if config.shard_optimizer_over_data: + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. + device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -362,6 +404,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] @@ -370,43 +414,65 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux["moe_bias_updates"] mtp_loss = aux["mtp_loss"] - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) + ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + new_state = state.apply_gradients(grads=grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for 
DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: grads = raw_grads - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - new_state = state.apply_gradients(grads=grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. 
+ device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() scalar_metrics = { "learning/loss": loss, @@ -417,8 +483,9 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) @@ -428,34 +495,41 @@ def move(path, value): if not config.optimizer_memory_host_offload: scalar_metrics["learning/grad_norm"] = 
max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + _, model_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] metrics = { "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + return nnx.state(new_state), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" - - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + else: + state = nnx.merge(model, state) # reconstruct 
TrainStateNNX + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -479,7 +553,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -501,32 +575,46 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = 
maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -535,7 +623,11 @@ def train_loop(config, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) 
+ metric_logger.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -545,57 +637,60 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. 
- gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. 
+ gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e4cad14906..3794e7fe93 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from 
maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,18 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) + is_nnx = isinstance(model, nnx.Module) + # For more efficient DP/ZeRO-1 + GA if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) 
+ # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +95,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) 
+ acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["total_loss"] acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +143,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -134,6 +160,9 @@ def reshape_to_microbatch_accumulations(batch_arr): raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 4a1335d732..e40dd619e7 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,21 +20,20 @@ import os from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import AxisType, Mesh - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from 
jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -48,6 +47,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -95,7 +95,10 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
@@ -106,7 +109,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -994,15 +1000,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1124,7 +1130,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ @@ -1163,19 +1169,35 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() if raw_params: # If we loaded a partial state, we need to merge it. 
- state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1191,6 +1213,9 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1234,6 +1259,148 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve + # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. 
+ return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = PartitionSpec() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. 
+ context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) + else: + pspec = PartitionSpec(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). + + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. 
+ """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. 
+ abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) 
+ state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 805d64bb21..9aefea56b1 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,8 +30,8 @@ """ Utils that are only interesting for creating a model in MaxText. 
""" from collections.abc import Sequence +from typing import Callable, overload from functools import partial -from typing import overload from etils import epath from flax import nnx import flax.linen as nn @@ -103,47 +117,62 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. +def get_nnx_create_model_fn(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None) -> Callable: - Unlike create_nnx_model, this does not shard parameters or load checkpoints. - It only builds the abstract shape/dtype structure needed by get_abstract_state - and optimizer construction (e.g. Muon). + def _create_model(): + is_training = model_mode == MODEL_MODE_TRAIN + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key + return _create_model - Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. - """ - def _create_model(rng_key=None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - rngs = nnx.Rngs(params=rng_key, dropout=1) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. - _create_model_partial = partial(_create_model, rng_key=rng_key) + Returns: + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. 
+ abstract_model: The stateful NNX model instance in an abstract state. + """ with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - - return _create_model_partial, abstract_model - - -def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" - is_training = model_mode == MODEL_MODE_TRAIN - - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) + + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. 
+ + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. + It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. + """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) @@ -153,17 +182,6 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, if mesh is None: mesh = abstract_model.mesh - # Note for pure_nnx: - # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and - # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen - # LogicallyPartitioned structure. - # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned - # structure in the abstract state and we can get the sharded state with the following code: - # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) - # abstract_model = nnx.merge(graphdef, state) - # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) - # sharded_state = nnx.state(model) - # JIT a function that creates the model state with proper sharding from the start. # By providing out_shardings, we instruct JAX to produce sharded output directly, # avoiding a large intermediate allocation on a single device. 
@@ -191,6 +209,26 @@ def create_sharded_state(): mesh=model.mesh, logical_annotations=specs, ) + maxtext_utils.print_shardings_params(sharded_state, out_shardings, model.mesh) + return model + + +def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model(config, mesh, devices, model_mode, rng_key) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) + + sharded_state = nnx.state(model) + + if mesh is None: + mesh = model.mesh + + with mesh: if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index f50acd269f..45b3dabdfc 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m MaxText.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,25 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param 
structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) + + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. + # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +131,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a 
wrapper. + elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -154,13 +178,17 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) return muon_weight_dimension_numbers @@ -172,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/sharding.py 
b/src/maxtext/utils/sharding.py index 74b22548b0..29b00a0c10 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. """ -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -468,6 +469,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): - updated_state_mesh_shardings: State mesh shardings with updated params field (unchanged if shard_optimizer_over_data is False) """ + if config.pure_nnx: + return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings) prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): @@ -486,6 +489,122 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return prev_params_shardings, state_mesh_shardings +def maybe_update_params_sharding_with_opt_nnx( + config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State +) -> tuple[nnx.State, nnx.State]: + """ + NNX version of parameter sharding update. Updates parameter sharding configuration + when optimizer state sharding is enabled. + + When shard_optimizer_over_data is enabled (Zero-1 style sharding), this function + extracts the optimizer state shardings from the Adam optimizer's first moment (mu) + and merges them with the parameter shardings. This ensures parameter sharding is + consistent with how the optimizer state is distributed across the compute mesh. + + Args: + config: Configuration with shard_optimizer_over_data flag. 
+ state_mesh_shardings: The sharding state for a TrainStateNNX container. + + Returns: + A tuple of (prev_params_shardings, updated_state_mesh_shardings): + - prev_params_shardings: Original parameter shardings before the update + - updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False)""" + # In TrainStateNNX, parameters are under 'model' + model_shardings = state_mesh_shardings.model + + def _extract_param_only(state): + """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + + Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree + structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + to work correctly between ga_params (Param-only) and params_shardings. + """ + result = {} + for k, v in state.items(): + if isinstance(v, nnx.Param): + result[k] = v + elif isinstance(v, nnx.Variable): + pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + elif hasattr(v, "items"): + sub = _extract_param_only(v) + if sub: + result[k] = sub + return result + + # prev_params_shardings must match the pytree structure of ga_params from + # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) + + if not config.shard_optimizer_over_data: + return prev_params_shardings, state_mesh_shardings + + sharded_fp32_params = None + # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key) + if "opt_state" in state_mesh_shardings.optimizer: + # Access the optimizer branch to find the optax state + # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer + opt_state = state_mesh_shardings.optimizer.opt_state + + def find_adam_mu(obj): + # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX) + if isinstance(obj, optax.ScaleByAdamState): + return obj.mu + + # 2. 
Check for flattened ScaleByAdamState (nnx.State/dict) + # These nodes contain 'mu', 'nu', and 'count' as keys. + if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj: + return obj["mu"] + + # 3. Recursive search through containers (nnx.State, dict, list, tuple) + values = None + if hasattr(obj, "values"): # Handles nnx.State and dict + values = obj.values() + elif isinstance(obj, (list, tuple)): + values = obj + + if values: + for v in values: + res = find_adam_mu(v) + if res is not None: + return res + return None + + sharded_fp32_params = find_adam_mu(opt_state) + if sharded_fp32_params is None: + actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None")) + raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}") + + # Update model parameter sharding to match the mu (first moment) sharding. + # This ensures parameter sharding is consistent with the Zero-1 distributed layout. + # Build a path → new_PS lookup from sharded_fp32_params (mu), then update model_shardings + # at those paths while preserving rngs and any other non-Param variables. + mu_leaves_with_paths = list( + jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + ) + mu_lookup = {path: mu_var.get_value() for path, mu_var in mu_leaves_with_paths} + + def _update_model_var(path, var): + if path in mu_lookup: + return var.replace(mu_lookup[path]) + return var + + new_model_shardings = jax.tree_util.tree_map_with_path( + _update_model_var, model_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + # Use jax.tree_util.tree_map (identity) to create a new nnx.State via JAX's unflatten + # mechanism (not the nnx.State constructor). This is critical because: + # 1. nnx.State({...}) constructor recursively converts nested plain dicts to nnx.State, + # causing a pytree type mismatch with the actual state from nnx.split (which stores + # nested module states as plain dicts). 
JAX's unflatten preserves the original types. + # 2. copy.deepcopy fails because NamedSharding contains non-picklable jaxlib.Device objects. + # Direct __setattr__ assignment stores new_model_shardings as-is (no type conversion). + updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + updated_state.model = new_model_shardings + + return prev_params_shardings, updated_state + + def logical_axis_rules_pp_act_as_dp(logical_rules): """Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP. This is used when we want to pipeline only a subset of layers, and leave the rest like DP. diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 9413b099ed..2a55c8b6b1 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os from functools import partial import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -202,7 +204,7 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator @@ -210,16 +212,22 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) if config.pure_nnx: # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) else: model = model_creation_utils.from_config(config, devices) - mesh = model.mesh learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # For NNX, the train state is wrapped in the TrainStateNNX module. + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) @@ -250,6 +258,15 @@ def setup_train_loop(config, recorder, devices=None): state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) 
+ _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -267,17 +284,24 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. 
+ logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: + if config.pure_nnx: + raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -302,12 +326,18 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: + # TODO: For pure_nnx, the dpo state manipulation is different. 
reference_params = step0_restored["items"].params["params"] state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -320,7 +350,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 7a09750a86..395348a27e 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,11 +15,13 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any, Sequence from collections.abc import Callable +from typing import Any, Sequence import unittest from unittest.mock import MagicMock, Mock, patch from dataclasses import dataclass, field +import numpy as np +import optax from flax import linen as nn from flax import nnx @@ -29,6 +31,7 @@ from jax import random, vmap import jax.numpy as jnp from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils from maxtext.configs import pyconfig from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils @@ -39,8 +42,7 @@ from maxtext.utils import sharding from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides -import numpy as np -import optax +from maxtext.utils import maxtext_utils_nnx Transformer = models.transformer_as_linen @@ -179,11 +181,7 @@ def setUp(self): "decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}}, } self.state = train_state.TrainState( - step=0, - 
apply_fn=self.model.apply, - params=self.initial_params, - tx=None, - opt_state={}, + step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={} ) def test_update_mode_add(self): @@ -196,10 +194,10 @@ def test_update_mode_add(self): self.assertTrue(jnp.allclose(actual, expected)) # Other values are untouched - original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] + original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] # pylint: disable=unsubscriptable-object new_layer_0 = new_state.params["layers"]["layer_0"]["bias"] self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0)) - original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] + original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] # pylint: disable=unsubscriptable-object new_layer_1 = new_state.params["layers"]["layer_1"]["bias"] self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1)) @@ -264,7 +262,7 @@ def test_init_training_state(self): @nnx.register_variable_name("special_variables") -class SpecialVariables(nnx.Variable): +class SpecialVariables(nnx.Variable): # pylint: disable=abstract-method pass @@ -281,7 +279,7 @@ def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None): return x -class TrainState(train_state.TrainState): +class TrainState(train_state.TrainState): # pylint: disable=abstract-method other_variables: nnx.State @@ -993,49 +991,63 @@ def train_step(_model, _config, _state_shardings, _params_shardings, state, _bat return train_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_step() result = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(len(result), 5) def test_functional_train_has_correct_name(self): 
step = self._make_mock_step() fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(fn.__name__, "train_step") - def test_in_shardings_structure(self): + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" step = self._make_mock_step() _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False) ) - # (state, batch, rng) self.assertEqual(len(in_shardings), 3) self.assertIsNone(in_shardings[2]) # rng sharding is None + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + def test_donate_argnums_is_zero(self): step = self._make_mock_step() _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(donate_argnums, 0) def test_functional_train_is_partial(self): """functional_train should partially apply model and config.""" received = {} + cfg = self._make_mock_config() def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): received["model"] = model received["config"] = config return state, {} - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") + fn, _, _, _, _ = 
maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg) fn("state", "batch") self.assertEqual(received["model"], "my_model") - self.assertEqual(received["config"], "my_config") + self.assertEqual(received["config"], cfg) class TestGetFunctionalEvalWithSignature(unittest.TestCase): @@ -1047,26 +1059,51 @@ def eval_step(_model, _config, _state, _batch, _rng=None): return eval_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_eval_step() - result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(len(result), 5) def test_functional_eval_has_correct_name(self): step = self._make_mock_eval_step() - fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(fn.__name__, "eval_step") def test_out_shardings_is_none(self): step = self._make_mock_eval_step() - _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertIsNone(out_shardings) def test_donate_argnums_is_empty(self): step = self._make_mock_eval_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertEqual(donate_argnums, ()) + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings 
should be (state, batch) — no rng slot.""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False) + ) + self.assertEqual(len(in_shardings), 3) + class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1414,5 +1451,105 @@ def test_runs_without_logical_annotations(self): maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + 
def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). 
+ opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index c3a43970ba..f8266d8ed4 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,17 +15,18 @@ """ Unit tests for all optimizers. 
""" import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax +import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -47,6 +48,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -55,6 +57,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -71,8 +74,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -97,6 +98,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -105,6 +107,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -121,8 +124,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -241,7 +242,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. 
""" - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True, pure_nnx=False) self.assertEqual(actual_output, expected_output) @@ -428,5 +429,105 @@ def learning_rate_schedule(step): self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0)) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def 
 __init__(self, rngs: nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scalar/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. 
+ model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..53318469fa --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,291 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). + # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. 
Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. 
Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. 
Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(state_linen)) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. 
+ linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(linen_as_nnx_dict)) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. 
+ self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(checkpoint_dict)) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. 
+ # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX tests.""" + +import unittest +import jax.numpy as jnp +from flax import nnx +import optax + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """Mocked NNX model""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class TestTrainStateNNX(unittest.TestCase): + """TrainStateNNX tests.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + self.tx = optax.adam(1e-3) + + def test_init_with_optimizer(self): + """Test init with iptimizer.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + self.assertEqual(state.model, self.model) + self.assertEqual(state.optimizer, optimizer) + # Access step directly from optimizer + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_optimizer(self): + """Test init without optimizer.""" + state = train_state_nnx.TrainStateNNX(self.model, None) + + self.assertEqual(state.model, self.model) + self.assertIsNone(state.optimizer) + + def test_apply_gradients_success(self): + """Test apply gradients can be called successfully.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # Create dummy gradients matching the model state structure + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state.model) + + # Apply gradients + state.apply_gradients(grads) + + # Verify step incremented (managed by nnx.Optimizer) + self.assertEqual(state.optimizer.step.value, 1) + + def test_apply_gradients_raises_runtime_error(self): + """Test apply gradients without a optimizer.""" + # Initialize without optimizer (inference mode) + state = train_state_nnx.TrainStateNNX(self.model, None) + + dummy_grads = {} + with self.assertRaises(RuntimeError) as cm: + 
state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main()