Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,10 @@

## v0.7.0 (2024-10-08)

### Breaking Changes

* **Removed `Axon.serialize/2` and `Axon.deserialize/2`** — Use `Nx.serialize/2` and `Nx.deserialize/2` for parameters instead. Axon recommends serializing only the trained parameters (weights) and keeping the model definition in code. See the [Saving and Loading](guides/serialization/saving_and_loading.livemd) guide.

### Bug Fixes

* Do not cast integers in `Axon.MixedPrecision.cast/2`
2 changes: 1 addition & 1 deletion guides/guides.md
@@ -28,5 +28,5 @@ Axon is a library for creating and training neural networks in Elixir. The Axon

## Serialization

* [Converting ONNX models to Axon](serialization/onnx_to_axon.livemd)
* [Saving and loading models](serialization/saving_and_loading.livemd)

159 changes: 159 additions & 0 deletions guides/serialization/saving_and_loading.livemd
@@ -0,0 +1,159 @@
# Saving and Loading Models

## Section

```elixir
Mix.install([
  {:axon, "~> 0.8"}
])
```

## Overview

Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach:

* Avoids serialization issues with anonymous functions and complex model structures
* Makes the model structure explicit and version-controlled in code
* Works reliably across processes and deployments

The model itself is just code: you define it once and reuse it. Only the learned parameters need to be persisted.
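
As a quick illustration, the round trip is just `Nx.serialize/2` and `Nx.deserialize/2` on a (possibly nested) map of tensors. This is a minimal sketch assuming only `nx` is installed; the map below is a made-up stand-in for what `ModelState.data` holds:

```elixir
# Minimal round-trip sketch. The nested map is a hypothetical stand-in
# for the parameter map stored in ModelState.data.
params = %{
  "dense_0" => %{
    "kernel" => Nx.tensor([[1.0, 2.0]]),
    "bias" => Nx.tensor([0.0])
  }
}

# Serialize to a binary and restore it; the restored structure mirrors
# the original map of tensors.
binary = Nx.serialize(params)
restored = Nx.deserialize(binary)
```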

## Saving a Model After Training

When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

trained_model_state =
  Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
```

The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference we only need the parameters, which live in the `data` field of the `ModelState` struct:

```elixir
# Extract parameters - trained_model_state.data contains the nested map of weights
params = trained_model_state.data

# Serialize and save
params_bytes = Nx.serialize(params)
File.write!("model_params.axon", params_bytes)
```

## Loading a Model for Inference

To load and run inference, you need:

1. The model definition (in code—the same structure you trained)
2. The saved parameters

```elixir
# 1. Define the same model structure (must match training)
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# 2. Load parameters
params = File.read!("model_params.axon") |> Nx.deserialize()

# 3. Run inference
input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input)
Axon.predict(model, params, %{"data" => input})
```

## Checkpointing During Training

To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later.

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
    {xs, Nx.sin(xs)}
  end)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```

Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
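
When resuming, you usually want the most recent checkpoint. One way to find it is a small helper (hypothetical, not part of Axon) that picks the newest file by modification time, which avoids lexicographic surprises once epoch numbers reach double digits (`checkpoint_10_...` sorts before `checkpoint_2_...` as a string):

```elixir
# Hypothetical helper, not part of Axon: return the newest file in a
# checkpoint directory by modification time.
latest_checkpoint = fn dir ->
  dir
  |> File.ls!()
  |> Enum.map(&Path.join(dir, &1))
  |> Enum.max_by(fn path -> File.stat!(path, time: :posix).mtime end)
end
```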

## Resuming from a Checkpoint

To resume training from a saved checkpoint:

1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
2. Attach it to your loop with `Axon.Loop.from_state/2`
3. Run the loop as usual

```elixir
# Load the checkpoint (use the path from your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Resume training
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

Comment on lines +108 to +128 (Contributor): @seanmor5 I think there's a bit of a dissonance between not having Axon.serialize/deserialize, while checkpoints need their Axon functions. WDYT?

Contributor Author: TBC

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.from_state(state)

# train_data is the same stream used in the checkpoint example above
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
```

## Saving Only Parameters from a Checkpoint

If you have a checkpoint file and want to extract parameters for inference (without optimizer state):

```elixir
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()

# Extract model parameters from step_state
%{model_state: model_state} = state.step_state
params = model_state.data

# Save for inference
File.write!("model_params.axon", Nx.serialize(params))
```

## Summary

| Use Case | Save | Load |
| ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- |
| Inference only | `Nx.serialize(params)` → file | `Nx.deserialize(file)` + model in code |
| Checkpoint (resume training) | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` |
| Extract params from checkpoint | `state.step_state.model_state.data` → `Nx.serialize` | Use with model in code |
2 changes: 1 addition & 1 deletion lib/axon/loop.ex
@@ -1511,7 +1511,7 @@ defmodule Axon.Loop do

It is the opposite of `Axon.Loop.serialize_state/2`.

By default, the step state is deserialized using `Nx.deserialize.2`;
By default, the step state is deserialized using `Nx.deserialize/2`;
however, this behavior can be changed if step state is an application
specific container. For example, if you introduce your own data
structure into step_state and you customized the serialization logic,
143 changes: 143 additions & 0 deletions test/axon/serialization_guide_test.exs
@@ -0,0 +1,143 @@
defmodule Axon.SerializationGuideTest do
  @moduledoc """
  Tests that validate the examples in guides/serialization/saving_and_loading.livemd.
  Run with: mix test test/axon/serialization_guide_test.exs
  """
  use Axon.Case, async: false

  @tmp_path Path.join(
              System.tmp_dir!(),
              "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}"
            )

  setup do
    File.mkdir_p!(@tmp_path)
    on_exit(fn -> File.rm_rf!(@tmp_path) end)
    [tmp_path: @tmp_path]
  end

  describe "saving and loading guide examples" do
    test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do
      # Same model as the guide
      model =
        Axon.input("data")
        |> Axon.dense(8)
        |> Axon.relu()
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0)

      train_data =
        Stream.repeatedly(fn ->
          {xs, _} =
            Nx.Random.normal(
              Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])})),
              shape: {8, 1}
            )

          {xs, Nx.sin(xs)}
        end)

      # Train
      trained_model_state =
        Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50)

      # Extract and save params (as in guide)
      params =
        case trained_model_state do
          %Axon.ModelState{data: data} -> data
          params when is_map(params) -> params
        end

      params_path = Path.join(tmp_path, "model_params.axon")
      params = Nx.backend_transfer(params)
      params_bytes = Nx.serialize(params)
      File.write!(params_path, params_bytes)

      # Load and predict (input shape must match training: {batch, 1} for 1 feature)
      loaded_params = File.read!(params_path) |> Nx.deserialize()
      input = Nx.tensor([[1.0]])

      prediction = Axon.predict(model, loaded_params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}
    end

    test "checkpoint and resume flow", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [
        {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])},
        {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])}
      ]

      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2)

      # Verify checkpoint was saved
      ckpt_files = File.ls!(checkpoint_path) |> Enum.sort()
      assert length(ckpt_files) == 2
      assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_"))

      # Load checkpoint and extract params for inference
      ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files))
      state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state()

      %{model_state: model_state} = state.step_state
      params = model_state.data

      # Run inference with extracted params
      input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
      prediction = Axon.predict(model, params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}
    end

    test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(2)
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints_resume")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}]

      # Run for 1 epoch
      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1)

      # Load checkpoint and resume
      [ckpt_file] = File.ls!(checkpoint_path)
      state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state()

      resumed_loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.from_state(state)

      # Resume - should complete without error
      result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2)

      assert %Axon.ModelState{} = result
    end
  end
end