
NNX migration prep (4/N): sharding tools and Linen<->NNX checkpoint utilities#3525

Draft
ecnal-cienet wants to merge 4 commits into main from
feat/nnx-linen-converter-and-sharding-tools

Conversation


@ecnal-cienet (Collaborator) commented Mar 31, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. 🔄 [This PR] NNX sharding diagnostics and bidirectional Linen↔NNX checkpoint conversion utilities.
  5. ❌ NNX unit tests and performance verification complete. Set pure_nnx=True as default.
  6. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the fourth in a series of NNX migration PRs. This PR adds developer tooling to inspect NNX sharding and convert / compare checkpoints across Linen and NNX formats. No training logic is changed.

Sharding diagnostics

  • maxtext_utils.py: print_shardings_params now dispatches on pure_nnx: for NNX models it iterates over the flat nnx.State rather than the Linen params tree.
  • tests/utils/run_sharding_dump.py: run_single_dump() now propagates --pure_nnx=true to the sharding-dump subprocess when the flag is set, enabling NNX sharding dumps without manual flag threading.
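The pure_nnx dispatch described above can be illustrated over plain nested dicts. This is a hypothetical stand-in (the real function works with nnx.State and JAX pytrees, and its signature may differ), kept dependency-free for clarity:

```python
def iter_leaves(tree, path=()):
    """Yield (path, leaf) pairs from a nested dict.

    Hypothetical stand-in for the two iteration modes above: with
    pure_nnx the real code walks the flat nnx.State, otherwise the
    Linen params pytree; over plain dicts both reduce to this walk.
    """
    if isinstance(tree, dict):
        for key, sub in tree.items():
            yield from iter_leaves(sub, path + (key,))
    else:
        yield path, tree


def print_shardings(tree):
    # One line per parameter: slash-joined path plus its sharding, if any.
    for path, leaf in iter_leaves(tree):
        print("/".join(map(str, path)), getattr(leaf, "sharding", None))
```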

Linen ↔ NNX checkpoint converter

src/maxtext/checkpoint_conversion/linen_nnx_converter.py — a standalone CPU-only script that bidirectionally converts Orbax checkpoints between Linen and NNX formats.

Key transformations handled:

| Direction | params tree | opt_state | step | Layer layout |
| --- | --- | --- | --- | --- |
| Linen → NNX | params/params/\<model\> → model/\<model\> + {value:} wrappers | remove params level from mu/nu | move inside optimizer/ | stack layers_N arrays → layers tensor (axis 1) |
| NNX → Linen | reverse of above | add params level | move to top level | unstack layers tensor → layers_N per-layer arrays |
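The layer-layout change in the last column can be sketched over plain nested dicts (as restored by Orbax). This is an illustrative helper, not the converter's actual code; it assumes one level of array leaves under each layers_N subtree:

```python
import re

import numpy as np


def stack_scanned_layers(params: dict) -> dict:
    """Collapse per-layer subtrees layers_0 … layers_{N-1} into a single
    'layers' entry whose arrays carry the layer dimension on axis 1,
    mirroring the Linen -> NNX layout change described above."""
    layer_keys = sorted(
        (k for k in params if re.fullmatch(r"layers_\d+", k)),
        key=lambda k: int(k.split("_")[1]),
    )
    if not layer_keys:
        return params
    # Copy everything that is not a per-layer subtree unchanged.
    out = {k: v for k, v in params.items() if k not in layer_keys}
    first = params[layer_keys[0]]
    # Stack each leaf array across layers on axis 1.
    out["layers"] = {
        name: np.stack([params[k][name] for k in layer_keys], axis=1)
        for name in first
    }
    return out
```

The NNX → Linen direction would invert this with np.split (or indexing) along the same axis.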

--direction accepts linen_to_nnx, nnx_to_linen, or auto (detects format from checkpoint keys).
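The auto mode's detection can be sketched from the key layouts in the table above (Linen nests parameters under params/params, NNX under model). A heuristic sketch; the script's real detection logic may look at more keys:

```python
def detect_checkpoint_format(tree: dict) -> str:
    """Guess whether a restored Orbax tree uses the Linen or NNX layout.

    Hypothetical heuristic based on the top-level key layouts described
    above, not the script's actual implementation."""
    # Linen: parameters live under a doubly nested params/params subtree.
    if isinstance(tree.get("params"), dict) and "params" in tree["params"]:
        return "linen"
    # NNX: parameters live under a top-level model subtree.
    if "model" in tree:
        return "nnx"
    raise ValueError(f"Cannot detect checkpoint format from keys: {list(tree)}")
```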

Checkpoint comparison utility

src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py — compares tree structure, shapes, and optionally values between any two Orbax checkpoints (Linen vs NNX, or same-format). Auto-detects format and applies cross-format normalization (layer axis transposition, {value:} unwrapping, RNG filtering) only when needed.
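One of those normalizations, stripping the NNX {value:} leaf wrappers so both trees can be compared key-for-key, can be sketched as follows (hypothetical helper name, not the script's actual API):

```python
def unwrap_values(tree):
    """Recursively replace NNX-style {"value": x} leaf wrappers with x,
    so an NNX tree lines up with a Linen tree. Illustrative sketch of
    the normalization described above."""
    if isinstance(tree, dict):
        # A dict whose only key is "value" is a wrapper; unwrap it.
        if set(tree) == {"value"}:
            return unwrap_values(tree["value"])
        return {k: unwrap_values(v) for k, v in tree.items()}
    return tree
```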

# Structure + shape comparison (Linen vs NNX)
python compare_linen_nnx_checkpoint.py \
  --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \
  --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items"

# Value comparison
python compare_linen_nnx_checkpoint.py \
  --ckpt_path_1="gs://bucket/ckpt_a/0/items" \
  --ckpt_path_2="gs://bucket/ckpt_b/0/items" \
  --compare_values --atol=1e-5 --rtol=1e-5
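The shape and value comparison with --atol/--rtol can be sketched as a recursive tree walk. A simplified stand-in for the utility (the real script also runs the cross-format normalization first and reports richer diagnostics):

```python
import numpy as np


def compare_trees(a, b, *, atol=1e-5, rtol=1e-5, path=""):
    """Recursively compare two checkpoint trees; return a list of
    mismatch descriptions (empty list means the trees match)."""
    problems = []
    if isinstance(a, dict) and isinstance(b, dict):
        for k in sorted(set(a) | set(b)):
            if k not in a or k not in b:
                problems.append(f"{path}/{k}: present on only one side")
            else:
                problems += compare_trees(
                    a[k], b[k], atol=atol, rtol=rtol, path=f"{path}/{k}"
                )
        return problems
    x, y = np.asarray(a), np.asarray(b)
    if x.shape != y.shape:
        problems.append(f"{path}: shape {x.shape} vs {y.shape}")
    elif not np.allclose(x, y, atol=atol, rtol=rtol):
        problems.append(f"{path}: values differ beyond atol={atol}, rtol={rtol}")
    return problems
```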

Tests

Unit tests:

python3 -m pytest tests/unit/linen_nnx_converter_test.py -v
python3 -m pytest tests/unit/compare_linen_nnx_checkpoint_test.py -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch from 63adaef to 1fe3f78 on March 31, 2026 15:08
xibinliu and others added 4 commits April 1, 2026 16:20
- pure_nnx: a flag to choose the pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for training.
  It is set to a different function for NNX and Linen.
- Add utils to manipulate the NNX shardings with abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
…ison utility

- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
@ecnal-cienet force-pushed the feat/nnx-linen-converter-and-sharding-tools branch from 1fe3f78 to 8aa93c4 on April 1, 2026 16:20