
NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop #3500

Draft
ecnal-cienet wants to merge 3 commits into main from feat/nnx-trainstate-and-training-loop

Conversation


@ecnal-cienet ecnal-cienet commented Mar 25, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. 🔄 [This PR] NNX fully supported end-to-end. pure_nnx=True enables full NNX training; default remains False.
  4. ❌ NNX unit tests and performance verification complete. Set pure_nnx=True as default.
  5. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the third in a series of NNX migration PRs. With this PR, pure_nnx=True runs a complete NNX training loop — initialization, sharding, gradient accumulation, eval, and checkpointing — without hitting any NotImplementedError. The pure_nnx flag still defaults to False, preserving the existing Linen workflow unchanged.

TrainStateNNX and unit tests

src/maxtext/layers/train_state_nnx.py implements the TrainStateNNX container, which holds an NNX model and its Optax optimizer as a single composable unit (a rough sketch follows the list below). Unit tests cover state creation, the optimizer step, and an Orbax checkpoint round-trip:

  • tests/unit/train_state_nnx_test.py
  • tests/unit/train_state_nnx_checkpoint_test.py
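
For orientation, the container's shape is roughly the following. This is a minimal sketch, not the PR's exact API: the composition via nnx.Optimizer and the apply_gradients method are assumptions based on flax v0.11-era NNX.

```python
import optax
from flax import nnx


class TrainStateNNX(nnx.Module):
  """Holds an NNX model and its Optax optimizer as one composable unit."""

  def __init__(self, model: nnx.Module, tx: optax.GradientTransformation):
    self.model = model
    # nnx.Optimizer tracks the Optax state (and a step counter) as NNX state,
    # so the whole container can be split/merged and checkpointed as a unit.
    self.optimizer = nnx.Optimizer(model, tx)

  def apply_gradients(self, grads: nnx.State):
    # Applies the Optax update to the model's nnx.Param values in place.
    self.optimizer.update(grads)


# Usage (illustrative):
#   model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
#   state = TrainStateNNX(model, optax.adamw(1e-3))
#   grads = nnx.grad(loss_fn)(state.model, batch)  # loss_fn(model, batch) -> scalar
#   state.apply_gradients(grads)
```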

Muon optimizer and model creation utilities

  • muon_utils.py — updated to support NNX models alongside Linen.
  • model_creation_utils.py — refactored to expose create_nnx_abstract_model and from_config, which create and initialize an NNX model from a config without running a full forward pass (see the sketch below).
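
The abstract-creation idea can be illustrated with nnx.eval_shape, which traces the constructor so parameter leaves come out as jax.ShapeDtypeStructs instead of real arrays. A sketch only: the model_cls(config, rngs=...) constructor signature is an assumption, not the actual from_config interface.

```python
from flax import nnx


def create_abstract_model(model_cls, config):
  """Sketch: build an NNX model without allocating real parameter arrays."""
  # nnx.eval_shape traces construction abstractly; every parameter leaf is a
  # jax.ShapeDtypeStruct, so this is cheap even for very large models.
  abstract_model = nnx.eval_shape(lambda: model_cls(config, rngs=nnx.Rngs(0)))
  graphdef, abstract_state = nnx.split(abstract_model)
  # abstract_state carries shapes/dtypes only; it is the natural input for
  # deriving shardings or as the target structure of a checkpoint restore.
  return graphdef, abstract_state
```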

End-to-end training loop (train.py + supporting modules)

The core training loop in train.py now dispatches on pure_nnx at every major decision point:

  • Sharding (sharding.py) — maybe_update_params_sharding_with_opt dispatches to a new maybe_update_params_sharding_with_opt_nnx, which extracts nnx.Param-only shardings from the flat nnx.State without accessing .params.
  • Gradient accumulation (gradient_accumulation.py) — NNX path uses nnx.value_and_grad with nnx.split / nnx.merge per microbatch inside jax.lax.scan, carrying non-Param rest state (RNGs) through the loop (see the first sketch after this list).
  • Train/eval step JIT (maxtext_utils.py) — get_functional_train_with_signature and get_functional_eval_with_signature use a 2-element in_shardings tuple (state, batch) for NNX (no rng argument), vs. 3-element for Linen.
  • Checkpointing (checkpointing.py) — maybe_save_checkpoint converts nnx.State to a plain dict via state.to_pure_dict() before the Orbax save; load_state_if_possible restores via nnx.replace_by_pure_dict(abstract_state, dict) (see the second sketch after this list).
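
The gradient-accumulation pattern described above looks roughly like this. A minimal sketch, not the PR's code: compute_loss is a hypothetical helper, and where the PR uses nnx.value_and_grad, this version uses the equivalent split/merge form with jax.value_and_grad so the scan carry is an explicit pytree.

```python
import jax
import jax.numpy as jnp
from flax import nnx


def accumulated_grads(model: nnx.Module, microbatches):
  """Sketch: sum per-microbatch grads under lax.scan, carrying non-Param state.

  `microbatches` is a pytree whose leaves have a leading microbatch axis.
  """
  # Split trainable Params from everything else (RNG counters, etc.).
  graphdef, params, rest = nnx.split(model, nnx.Param, ...)

  def step(carry, batch):
    grad_sum, rest = carry

    def loss_fn(p):
      m = nnx.merge(graphdef, p, rest)
      loss = compute_loss(m, batch)  # hypothetical loss helper
      # Re-split to capture state mutated during the forward pass (e.g. RNGs).
      _, _, new_rest = nnx.split(m, nnx.Param, ...)
      return loss, new_rest

    (loss, new_rest), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    grad_sum = jax.tree.map(jnp.add, grad_sum, grads)
    return (grad_sum, new_rest), loss

  zeros = jax.tree.map(jnp.zeros_like, params)
  (grad_sum, rest), losses = jax.lax.scan(step, (zeros, rest), microbatches)
  # Mean over microbatches; the caller can merge `rest` back into the model.
  return jax.tree.map(lambda g: g / losses.shape[0], grad_sum), rest
```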
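The checkpoint round-trip in the last bullet follows the standard NNX/Orbax pure-dict recipe. A minimal sketch, assuming StandardCheckpointer and absolute directory paths; the real maybe_save_checkpoint / load_state_if_possible wrap considerably more logic.

```python
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()


def save(model: nnx.Module, path: str):
  _, state = nnx.split(model)
  # to_pure_dict() turns the nnx.State into a plain nested dict of arrays,
  # which Orbax serializes directly.
  checkpointer.save(path, state.to_pure_dict())


def restore(abstract_model: nnx.Module, path: str):
  graphdef, abstract_state = nnx.split(abstract_model)
  restored = checkpointer.restore(path)
  # Fill the abstract (shape-only) state with the restored arrays in place.
  nnx.replace_by_pure_dict(abstract_state, restored)
  return nnx.merge(graphdef, abstract_state)
```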

Tests

Unit tests:

python3 -m pytest tests/unit/train_state_nnx_test.py -v
python3 -m pytest tests/unit/train_state_nnx_checkpoint_test.py -v
python3 -m pytest tests/unit/sharding_nnx_test.py -v
python3 -m pytest tests/unit/maxtext_utils_nnx_test.py -v
python3 -m pytest tests/unit/train_compile_test.py -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  - [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  - [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
  - [ ] I have run end-to-end tests and provided workload links above if applicable.
  - [ ] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 3 times, most recently from 4bae533 to e6baabd on March 25, 2026 21:48
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch 2 times, most recently from 754df44 to 8055cc8 on March 26, 2026 17:10
@ecnal-cienet ecnal-cienet changed the title from "Feat/nnx trainstate and training loop" to "NNX migration prep (3/N): Feat/nnx trainstate and training loop" on Mar 26, 2026
@ecnal-cienet ecnal-cienet changed the title from "NNX migration prep (3/N): Feat/nnx trainstate and training loop" to "NNX migration prep (3/N): TrainState, model creation, and end-to-end training loop" on Mar 27, 2026
@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-trainstate-and-training-loop branch from 8055cc8 to a906f15 on March 30, 2026 14:49
xibinliu and others added 3 commits March 31, 2026 13:59
- pure_nnx: a flag to choose the pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for training.
  It will be set to a different function for NNX and Linen.
- Add utils to manipulate NNX shardings via the abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func

Note:
flax v0.12 emits DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed; this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change the code for these warnings.
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch