Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Enforce the array shape and type check during Array restoration when
`ArrayRestoreArgs.strict` is set but shape/dtype is not provided.
- On platforms where `uvloop` is not supported, fallback to `nest_asyncio`.
- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level
merging.

## [0.11.33] - 2026-02-17

Expand Down
81 changes: 43 additions & 38 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,27 +184,6 @@ class PyTreeOptions:

# TODO: Include an example of registering a custom LeafHandler.

Example:
To save certain leaves in float16, while others in float32, we can use
`create_array_storage_options_fn` like so::

import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import options as ocp_options

def create_opts_fn(keypath, value):
if 'small' in jax.tree_util.keystr(keypath):
return ocp_options.ArrayOptions.Saving.StorageOptions(
dtype=jnp.float16
)
return ocp_options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float32)

pytree_options = ocp_options.PyTreeOptions(
saving=ocp_options.PyTreeOptions.Saving(
create_array_storage_options_fn=create_opts_fn
)
)

Attributes:
saving: Options for saving PyTrees.
loading: Options for loading PyTrees.
Expand All @@ -216,25 +195,9 @@ def create_opts_fn(keypath, value):
class Saving:
"""Options for saving PyTrees.

create_array_storage_options_fn:
A function that is called in order to create
:py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree,
when it is
being saved. It is called similar to:
`jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`.
If provided, it overrides any default settings in
:py:class:`.ArrayOptions.Saving.StorageOptions`.
pytree_metadata_options: Options for managing PyTree metadata.
"""

class CreateArrayStorageOptionsFn(Protocol):

def __call__(
self, key: tree_types.PyTreeKeyPath, value: Any
) -> ArrayOptions.Saving.StorageOptions:
...

create_array_storage_options_fn: CreateArrayStorageOptionsFn | None = None
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions)
)
Expand Down Expand Up @@ -265,7 +228,8 @@ class ArrayOptions:
names during initialization.

Example:
Configure array options with specific saving formats and loading behaviors::
To configure array options with specific saving formats and loading
behaviors we can do so like this::

from orbax.checkpoint.v1.options import ArrayOptions

Expand All @@ -280,6 +244,30 @@ class ArrayOptions:
)
)

To save certain leaves in float16, while others in float32, we can use
`scoped_storage_options_creator` like so::

import jax
import jax.numpy as jnp
from orbax.checkpoint.v1 import options as ocp_options

def create_opts_fn(keypath, value):
if 'small' in jax.tree_util.keystr(keypath):
return ocp_options.ArrayOptions.Saving.StorageOptions(
dtype=jnp.float16
)
return None # Fall back to global `storage_options`

array_options = ocp_options.ArrayOptions(
saving=ocp_options.ArrayOptions.Saving(
storage_options=ocp_options.ArrayOptions.Saving.StorageOptions(
dtype=jnp.float32
),
scoped_storage_options_creator=create_opts_fn
)

)

Attributes:
saving: Options for saving arrays.
loading: Options for loading arrays.
Expand Down Expand Up @@ -322,8 +310,24 @@ class Saving:
True.
array_metadata_store: Store to manage per host ArrayMetadata. To disable
ArrayMetadata persistence, set it to None.
storage_options: Global default for array storage options.
scoped_storage_options_creator: A function that, when dealing with
PyTrees, is applied to every leaf. If it returns an
:py:class:`ArrayOptions.Saving.StorageOptions`, its fields take
precedence when merging if they are set to non-None or non-default
values with respect to `storage_options`. If it returns `None`,
`storage_options` is used as a default for all fields. It is called
similar to: `jax.tree.map_with_path(scoped_storage_options_creator,
pytree_to_save)`.
"""

class ScopedStorageOptionsCreator(Protocol):

def __call__(
self, key: tree_types.PyTreeKeyPath, value: Any
) -> ArrayOptions.Saving.StorageOptions:
...

@dataclasses.dataclass(frozen=True, kw_only=True)
class StorageOptions:
"""Options used to customize array storage behavior for individual leaves.
Expand Down Expand Up @@ -367,6 +371,7 @@ class StorageOptions:
array_metadata_store: array_metadata_store_lib.Store | None = (
array_metadata_store_lib.Store()
)
scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None

@dataclasses.dataclass(frozen=True, kw_only=True)
class Loading:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint._src.serialization import types as v0_serialization_types
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import compatibility
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
from orbax.checkpoint.experimental.v1._src.serialization import registry
from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler
Expand Down Expand Up @@ -69,32 +71,19 @@ def _get_remaining_timeout(

def _get_v0_save_args(
checkpointable: PyTree,
array_storage_options: options_lib.ArrayOptions.Saving.StorageOptions,
create_array_storage_options_fn: (
options_lib.PyTreeOptions.Saving.CreateArrayStorageOptionsFn | None
),
array_saving_options: options_lib.ArrayOptions.Saving,
) -> PyTree:
"""Returns save args that are compatible with the V0 API."""

def _leaf_get_v0_save_args(k, v):
if create_array_storage_options_fn:
individual_array_storage_options = create_array_storage_options_fn(k, v)
save_dtype = (
np.dtype(individual_array_storage_options.dtype)
if individual_array_storage_options.dtype
else None
)
return v0_serialization_types.SaveArgs(
dtype=save_dtype,
chunk_byte_size=individual_array_storage_options.chunk_byte_size,
shard_axes=individual_array_storage_options.shard_axes,
)
return v0_serialization_types.SaveArgs(
dtype=np.dtype(array_storage_options.dtype)
if array_storage_options.dtype
resolved_options = options_resolution.resolve_storage_options(
k, v, array_saving_options
)
return type_handlers_v0.SaveArgs(
dtype=np.dtype(resolved_options.dtype)
if resolved_options.dtype is not None
else None,
chunk_byte_size=array_storage_options.chunk_byte_size,
shard_axes=array_storage_options.shard_axes,
chunk_byte_size=resolved_options.chunk_byte_size,
shard_axes=resolved_options.shard_axes,
)

return jax.tree.map_with_path(_leaf_get_v0_save_args, checkpointable)
Expand Down Expand Up @@ -135,8 +124,7 @@ def create_v0_save_args(
item=checkpointable,
save_args=_get_v0_save_args(
checkpointable,
context.array_options.saving.storage_options,
context.pytree_options.saving.create_array_storage_options_fn,
context.array_options.saving,
),
ocdbt_target_data_file_size=context.array_options.saving.ocdbt_target_data_file_size,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
import orbax.checkpoint.experimental.v1._src.context.options as options_lib
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils
from orbax.checkpoint.experimental.v1._src.serialization import registration
from orbax.checkpoint.experimental.v1._src.serialization import types
Expand Down Expand Up @@ -109,18 +111,18 @@ def _create_v0_saving_paraminfo(

def _create_v0_savearg(
param: ArraySerializationParam,
context: context_lib.Context,
array_saving_options: options_lib.ArrayOptions.Saving,
) -> type_handlers_v0.SaveArgs:
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
fn = context.pytree_options.saving.create_array_storage_options_fn
if fn:
storage_options = fn(param.keypath, param.value)
else:
storage_options = context.array_options.saving.storage_options
"""Creates a V0 `SaveArgs` from V1 params and array options for saving."""
resolved_options = options_resolution.resolve_storage_options(
param.keypath, param.value, array_saving_options
)
return type_handlers_v0.SaveArgs(
dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None,
chunk_byte_size=storage_options.chunk_byte_size,
shard_axes=storage_options.shard_axes,
dtype=jnp.dtype(resolved_options.dtype)
if resolved_options.dtype is not None
else None,
chunk_byte_size=resolved_options.chunk_byte_size,
shard_axes=resolved_options.shard_axes,
)


Expand Down Expand Up @@ -223,7 +225,10 @@ async def serialize(
_create_v0_saving_paraminfo(p, self._context, serialization_context)
for p in params
]
saveargs = [_create_v0_savearg(p, self._context) for p in params]
saveargs = [
_create_v0_savearg(p, self._context.array_options.saving)
for p in params
]

commit_futures = await self._handler_impl.serialize(
values, paraminfos, saveargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
import orbax.checkpoint.experimental.v1._src.context.options as options_lib
from orbax.checkpoint.experimental.v1._src.serialization import options_resolution
from orbax.checkpoint.experimental.v1._src.serialization import registration
from orbax.checkpoint.experimental.v1._src.serialization import types

Expand Down Expand Up @@ -96,18 +98,18 @@ def _create_v0_saving_paraminfo(

def _create_v0_savearg(
param: NumpySerializationParam,
context: context_lib.Context,
array_saving_options: options_lib.ArrayOptions.Saving,
) -> type_handlers_v0.SaveArgs:
"""Creates a V0 `SaveArgs` from V1 params and context for saving."""
fn = context.pytree_options.saving.create_array_storage_options_fn
if fn:
storage_options = fn(param.keypath, param.value)
else:
storage_options = context.array_options.saving.storage_options
"""Creates a V0 `SaveArgs` from V1 params and array saving options."""
resolved_options = options_resolution.resolve_storage_options(
param.keypath, param.value, array_saving_options
)
return type_handlers_v0.SaveArgs(
dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None,
chunk_byte_size=storage_options.chunk_byte_size,
shard_axes=storage_options.shard_axes,
dtype=np.dtype(resolved_options.dtype)
if resolved_options.dtype is not None
else None,
chunk_byte_size=resolved_options.chunk_byte_size,
shard_axes=resolved_options.shard_axes,
)


Expand Down Expand Up @@ -188,7 +190,10 @@ async def serialize(
_create_v0_saving_paraminfo(p, self._context, serialization_context)
for p in params
]
saveargs = [_create_v0_savearg(p, self._context) for p in params]
saveargs = [
_create_v0_savearg(p, self._context.array_options.saving)
for p in params
]

commit_futures = await self._handler_impl.serialize(
values, paraminfos, saveargs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for serialization."""

from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


def resolve_storage_options(
keypath: tree_types.PyTreeKeyPath,
value: tree_types.LeafType,
array_saving_options: options_lib.ArrayOptions.Saving,
) -> options_lib.ArrayOptions.Saving.StorageOptions:
"""Resolves storage options using a global default and a per-leaf creator.

When dealing with PyTrees, `scoped_storage_options_creator` is applied to
every leaf. Its fields take precedence when merging if they are set to
non-None or non-default values with respect to the global `storage_options`.
If the creator returns `None`, the global `storage_options` is used for all
fields.

Args:
keypath: The PyTree keypath of the array being saved.
value: The PyTree leaf value (array) being saved.
array_saving_options: The Orbax array saving options to use for resolution.

Returns:
The resolved StorageOptions containing storage options.
"""
global_opts = array_saving_options.storage_options
if global_opts is None:
global_opts = options_lib.ArrayOptions.Saving.StorageOptions()

fn = array_saving_options.scoped_storage_options_creator
individual_opts = None
if fn is not None:
individual_opts = fn(keypath, value)

if individual_opts is not None:
resolved_dtype = (
individual_opts.dtype
if individual_opts.dtype is not None
else global_opts.dtype
)
resolved_chunk_byte_size = (
individual_opts.chunk_byte_size
if individual_opts.chunk_byte_size is not None
else global_opts.chunk_byte_size
)
resolved_shard_axes = (
individual_opts.shard_axes
if individual_opts.shard_axes
else global_opts.shard_axes
)
else:
resolved_dtype = global_opts.dtype
resolved_chunk_byte_size = global_opts.chunk_byte_size
resolved_shard_axes = global_opts.shard_axes

return options_lib.ArrayOptions.Saving.StorageOptions(
dtype=resolved_dtype,
chunk_byte_size=resolved_chunk_byte_size,
shard_axes=resolved_shard_axes,
)

Loading
Loading