-
Notifications
You must be signed in to change notification settings - Fork 123
Add docs for serialization with NX.serialize #630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # Saving and Loading Models | ||
|
|
||
| ## Section | ||
|
|
||
| ```elixir | ||
| Mix.install([ | ||
| {:axon, "~> 0.8"} | ||
| ]) | ||
| ``` | ||
|
|
||
| ## Overview | ||
|
|
||
| Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach: | ||
|
|
||
| * Avoids serialization issues with anonymous functions and complex model structures | ||
| * Makes the model structure explicit and version-controlled in code | ||
| * Works reliably across processes and deployments | ||
|
|
||
| The model itself is just code, you define it once and reuse it. Only the learned parameters need to be persisted. | ||
|
|
||
| ## Saving a Model After Training | ||
|
|
||
| When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them: | ||
|
|
||
| ```elixir | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd) | ||
|
|
||
| train_data = | ||
| Stream.repeatedly(fn -> | ||
| {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1}) | ||
| {xs, Nx.sin(xs)} | ||
| end) | ||
|
|
||
| trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100) | ||
| ``` | ||
|
|
||
| The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `data` field from `ModelState`: | ||
|
|
||
| ```elixir | ||
| # Extract parameters - trained_model_state.data contains the nested map of weights | ||
| params = trained_model_state.data | ||
|
|
||
| # Serialize and save | ||
| params_bytes = Nx.serialize(params) | ||
| File.write!("model_params.axon", params_bytes) | ||
| ``` | ||
|
|
||
| ## Loading a Model for Inference | ||
|
|
||
| To load and run inference, you need: | ||
|
|
||
| 1. The model definition (in code—the same structure you trained) | ||
| 2. The saved parameters | ||
|
|
||
| ```elixir | ||
| # 1. Define the same model structure (must match training) | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| # 2. Load parameters | ||
| params = File.read!("model_params.axon") |> Nx.deserialize() | ||
|
|
||
| # 3. Run inference | ||
| input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input) | ||
| Axon.predict(model, params, %{"data" => input}) | ||
| ``` | ||
|
|
||
| ## Checkpointing During Training | ||
|
|
||
| To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later. | ||
|
|
||
| ```elixir | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd) | ||
| |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed) | ||
|
|
||
| train_data = | ||
| Stream.repeatedly(fn -> | ||
| {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1}) | ||
| {xs, Nx.sin(xs)} | ||
| end) | ||
|
|
||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50) | ||
| ``` | ||
|
|
||
| Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`. | ||
|
|
||
| ## Resuming from a Checkpoint | ||
|
|
||
| To resume training from a saved checkpoint: | ||
|
|
||
| 1. Load the checkpoint with `Axon.Loop.deserialize_state/2` | ||
| 2. Attach it to your loop with `Axon.Loop.from_state/2` | ||
| 3. Run the loop as usual | ||
|
|
||
| ```elixir | ||
| # Load the checkpoint (use the path from your checkpoint files) | ||
| checkpoint_path = "checkpoints/checkpoint_2_50.ckpt" | ||
| serialized = File.read!(checkpoint_path) | ||
| state = Axon.Loop.deserialize_state(serialized) | ||
|
|
||
| # Resume training | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd) | ||
| |> Axon.Loop.from_state(state) | ||
|
|
||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50) | ||
| ``` | ||
|
|
||
| ## Saving Only Parameters from a Checkpoint | ||
|
|
||
| If you have a checkpoint file and want to extract parameters for inference (without optimizer state): | ||
|
|
||
| ```elixir | ||
| checkpoint_path = "checkpoints/checkpoint_2_50.ckpt" | ||
| state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state() | ||
|
|
||
| # Extract model parameters from step_state | ||
| %{model_state: model_state} = state.step_state | ||
| params = model_state.data | ||
|
|
||
| # Save for inference | ||
| File.write!("model_params.axon", Nx.serialize(params)) | ||
| ``` | ||
|
|
||
| ## Summary | ||
|
|
||
| | Use Case | Save | Load | | ||
| | ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- | | ||
| | Inference only | `Nx.serialize(params)` → file | `Nx.deserialize(file)` + model in code | | ||
| | Checkpoint (resume training) | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` | | ||
| | Extract params from checkpoint | `state.step_state.model_state.data` → `Nx.serialize` | Use with model in code | | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| defmodule Axon.SerializationGuideTest do | ||
| @moduledoc """ | ||
| Tests that validate the examples in guides/serialization/saving_and_loading.livemd. | ||
| Run with: mix test test/axon/serialization_guide_test.exs | ||
| """ | ||
| use Axon.Case, async: false | ||
|
|
||
| @tmp_path Path.join( | ||
| System.tmp_dir!(), | ||
| "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}" | ||
| ) | ||
|
|
||
| setup do | ||
| File.mkdir_p!(@tmp_path) | ||
| on_exit(fn -> File.rm_rf!(@tmp_path) end) | ||
| [tmp_path: @tmp_path] | ||
| end | ||
|
|
||
| describe "saving and loading guide examples" do | ||
| test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do | ||
| # Same model as the guide | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0) | ||
|
|
||
| train_data = | ||
| Stream.repeatedly(fn -> | ||
| {xs, _} = | ||
| Nx.Random.normal( | ||
| Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])})), | ||
| shape: {8, 1} | ||
| ) | ||
|
|
||
| {xs, Nx.sin(xs)} | ||
| end) | ||
|
|
||
| # Train | ||
| trained_model_state = | ||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50) | ||
|
|
||
| # Extract and save params (as in guide) | ||
| params = | ||
| case trained_model_state do | ||
| %Axon.ModelState{data: data} -> data | ||
| params when is_map(params) -> params | ||
| end | ||
|
|
||
| params_path = Path.join(tmp_path, "model_params.axon") | ||
| params = Nx.backend_transfer(params) | ||
| params_bytes = Nx.serialize(params) | ||
| File.write!(params_path, params_bytes) | ||
|
|
||
| # Load and predict (input shape must match training: {batch, 1} for 1 feature) | ||
| loaded_params = File.read!(params_path) |> Nx.deserialize() | ||
| input = Nx.tensor([[1.0]]) | ||
|
|
||
| prediction = Axon.predict(model, loaded_params, %{"data" => input}) | ||
|
|
||
| assert Nx.rank(prediction) == 2 | ||
| assert Nx.shape(prediction) == {1, 1} | ||
| end | ||
|
|
||
| test "checkpoint and resume flow", %{tmp_path: tmp_path} do | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| checkpoint_path = Path.join(tmp_path, "checkpoints") | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) | ||
| |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed) | ||
|
|
||
| train_data = [ | ||
| {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])}, | ||
| {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])} | ||
| ] | ||
|
|
||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2) | ||
|
|
||
| # Verify checkpoint was saved | ||
| ckpt_files = File.ls!(checkpoint_path) |> Enum.sort() | ||
| assert length(ckpt_files) == 2 | ||
| assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_")) | ||
|
|
||
| # Load checkpoint and extract params for inference | ||
| ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files)) | ||
| state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state() | ||
|
|
||
| %{model_state: model_state} = state.step_state | ||
| params = model_state.data | ||
|
|
||
| # Run inference with extracted params | ||
| input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]]) | ||
| prediction = Axon.predict(model, params, %{"data" => input}) | ||
|
|
||
| assert Nx.rank(prediction) == 2 | ||
| assert Nx.shape(prediction) == {1, 1} | ||
| end | ||
|
|
||
| test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(2) | ||
| |> Axon.dense(1) | ||
|
|
||
| checkpoint_path = Path.join(tmp_path, "checkpoints_resume") | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) | ||
| |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed) | ||
|
|
||
| train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}] | ||
|
|
||
| # Run for 1 epoch | ||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1) | ||
|
|
||
| # Load checkpoint and resume | ||
| [ckpt_file] = File.ls!(checkpoint_path) | ||
| state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state() | ||
|
|
||
| resumed_loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) | ||
| |> Axon.Loop.from_state(state) | ||
|
|
||
| # Resume - should complete without error | ||
| result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2) | ||
|
|
||
| assert %Axon.ModelState{} = result | ||
| end | ||
| end | ||
| end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@seanmor5 I think there's a bit of a dissonance between not having Axon.serialize/deserialize, while checkpoints need their Axon functions. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBC