Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,30 @@

"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats."""

import asyncio
import collections
import json
from typing import Any, Awaitable, Sequence, cast
from typing import Any, Awaitable, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint._src.arrays import numpy_utils
from orbax.checkpoint._src.multihost import multihost as multihost_v0
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types
from orbax.checkpoint.experimental.v1._src.path import types
from orbax.checkpoint.experimental.v1._src.synchronization import multihost

CheckpointLayout = checkpoint_layout.CheckpointLayout
InvalidLayoutError = checkpoint_layout.InvalidLayoutError
Path = types.Path

HEADER_NUM_BYTES = 8
SAFETENSORS_SUFFIX = ".safetensors"
MAX_GAP_SIZE_BYTES = (
32 * 1024 * 1024
) # 32 MB gap allowed between tensors in a coalesced read block


def _get_dtypes() -> dict[str, Any]:
Expand Down Expand Up @@ -92,75 +97,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]:
return shape, dtype


async def _read_non_contiguous_slice(
    f: async_path.AsyncFile,
    idx: tuple[slice, ...],
    stored_shape: tuple[int, ...],
    stored_dtype: np.dtype,
    tensor_file_offset: int,
) -> np.ndarray:
  """Reads an n-dimensional slice of a row-major tensor from a file.

  A multi-dimensional slice is generally not one contiguous byte range on
  disk. This helper descends through the slice's dimensions: only along the
  innermost dimension are the requested elements adjacent, so each innermost
  run becomes a single `seek` + `read`, and outer dimensions simply iterate
  and concatenate the runs.

  Args:
    f: The asynchronous file object, opened in binary read mode.
    idx: Tuple of `slice` objects selecting the n-dimensional region to read.
    stored_shape: Shape of the full tensor as stored in the file.
    stored_dtype: Dtype of the stored tensor.
    tensor_file_offset: Byte offset of the tensor's data within the file.

  Returns:
    A NumPy array holding exactly the requested slice.
  """
  itemsize = np.dtype(stored_dtype).itemsize

  # 0-d tensor: there is a single scalar at the tensor's start offset.
  if not idx:
    await f.seek(tensor_file_offset)
    scalar_raw = await f.read(itemsize)
    # Reshape to () so the result is a 0-D array, not 1-D of length 1.
    return np.frombuffer(scalar_raw, dtype=stored_dtype).reshape(())

  # Row-major byte strides of the *full* stored tensor: the stride of a
  # dimension is how many bytes separate consecutive indices along it.
  ndim = len(stored_shape)
  strides = [itemsize] * ndim
  for axis in range(ndim - 2, -1, -1):
    strides[axis] = strides[axis + 1] * stored_shape[axis + 1]

  async def _gather(axis: int, base: int) -> bytes:
    # TODO(b/438763866) - @zachmeyers to consider alternative methods.
    sel = idx[axis]
    if axis == ndim - 1:
      # Innermost axis: the selected elements are adjacent on disk, so a
      # single read covers the whole run.
      begin = base + sel.start * strides[axis]
      await f.seek(tensor_file_offset + begin)
      return await f.read((sel.stop - sel.start) * itemsize)
    # Outer axis: recurse once per selected index and join the pieces.
    pieces = []
    for row in range(sel.start, sel.stop):
      pieces.append(await _gather(axis + 1, base + row * strides[axis]))
    return b"".join(pieces)

  payload = await _gather(0, 0)
  result_shape = numpy_utils.slice_shape(idx)
  return np.frombuffer(payload, dtype=stored_dtype).reshape(result_shape)


async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
"""Loads tensors from a safetensors file into host NumPy arrays."""
header, data_start_offset = await _read_safetensors_header(path)
Expand All @@ -179,65 +115,201 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]:
return tensors


def _create_non_sharded_array(
raw_data: memoryview | bytes,
abstract_leaf: Any,
stored_shape: tuple[int, ...],
stored_dtype: Any,
) -> jax.Array:
"""Creates a non-sharded JAX array from raw bytes."""
np_array = np.frombuffer(raw_data, dtype=stored_dtype).reshape(stored_shape)
target_dtype = abstract_leaf.dtype
if np_array.dtype != target_dtype:
np_array = np_array.astype(target_dtype)
return jax.device_put(np_array)


def _create_sharded_array(
    raw_data: memoryview | bytes,
    abstract_leaf: Any,
    stored_shape: tuple[int, ...],
    stored_dtype: Any,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Creates a sharded JAX array from raw bytes.

  Args:
    raw_data: This host's contiguous chunk of the flattened tensor bytes
      (presumably read by `_load_sharded_array` with matching bounds —
      confirm with caller).
    abstract_leaf: Abstract array supplying the target `sharding` and `dtype`.
    stored_shape: Global shape of the tensor as stored in the checkpoint.
    stored_dtype: Dtype of the tensor as stored in the checkpoint.
    num_hosts: Total number of participating hosts.
    host_id: Index of this host in `[0, num_hosts)`.
    flat_sharding: 1-D sharding that assigns one equal-sized chunk per host.

  Returns:
    A global JAX array of shape `stored_shape` placed per
    `abstract_leaf.sharding`.
  """
  sharding = abstract_leaf.sharding
  target_dtype = abstract_leaf.dtype

  # Use 1D flat contiguous read + reshard logic for maximum IO throughput.
  total_elements = int(np.prod(stored_shape)) if stored_shape else 1

  # Calculate padding: every host must contribute a chunk of identical size,
  # so round the element count up to a multiple of `num_hosts`.
  elements_per_host = (total_elements + num_hosts - 1) // num_hosts
  padded_elements = elements_per_host * num_hosts

  # This host's half-open element range [start_idx, end_idx) of the flat
  # tensor; the final host's range may be short (or empty).
  start_idx = host_id * elements_per_host
  end_idx = min((host_id + 1) * elements_per_host, total_elements)
  num_elements_to_read = max(0, end_idx - start_idx)

  local_data = np.frombuffer(raw_data, dtype=stored_dtype)
  if local_data.dtype != target_dtype:
    local_data = local_data.astype(target_dtype)

  # Zero-pad a short trailing chunk so each host holds exactly
  # `elements_per_host` elements of the flat array.
  if num_elements_to_read < elements_per_host:
    local_data = np.pad(
        local_data, (0, elements_per_host - num_elements_to_read)
    )

  # Put local data on all addressable devices in the flat sharding
  # (the host's chunk is replicated across its local devices).
  local_arrays = [
      jax.device_put(local_data, d) for d in flat_sharding.addressable_devices
  ]

  # Create the 1D sharded array
  flat_array = jax.make_array_from_single_device_arrays(
      (padded_elements,), flat_sharding, local_arrays
  )

  # Slice off the padding and reshape
  if padded_elements > total_elements:
    flat_array = flat_array[:total_elements]

  reshaped_array = flat_array.reshape(stored_shape)

  # Reshard to the target sharding
  target_array = jax.device_put(reshaped_array, sharding)

  return target_array


async def _load_non_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
) -> jax.Array:
  """Loads a single non-sharded array from a safetensors file.

  Args:
    path: Path of the safetensors file.
    abstract_leaf: Abstract array supplying the target dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset where the data section begins in the file.

  Returns:
    The tensor as a device-resident JAX array.
  """
  stored_shape, stored_dtype = _get_array_properties(header_info)
  # `data_offsets` are relative to the start of the data section.
  begin, end = header_info["data_offsets"]

  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(data_start_offset + begin)
    payload = await f.read(end - begin)

  return _create_non_sharded_array(
      payload, abstract_leaf, stored_shape, stored_dtype
  )


async def _load_sharded_array(
    path: Path,
    abstract_leaf: Any,
    header_info: dict[str, Any],
    data_start_offset: int,
    num_hosts: int,
    host_id: int,
    flat_sharding: jax.sharding.NamedSharding,
) -> jax.Array:
  """Loads a single sharded array from a safetensors file.

  Reads only this host's contiguous chunk of the flattened tensor and hands
  it to `_create_sharded_array` for assembly and resharding.

  Args:
    path: Path of the safetensors file.
    abstract_leaf: Abstract array supplying the target sharding and dtype.
    header_info: This tensor's entry from the safetensors header.
    data_start_offset: Byte offset where the data section begins in the file.
    num_hosts: Total number of participating hosts.
    host_id: Index of this host in `[0, num_hosts)`.
    flat_sharding: 1-D sharding assigning one equal chunk per host.

  Returns:
    A global JAX array placed per `abstract_leaf.sharding`.
  """
  stored_shape, stored_dtype = _get_array_properties(header_info)
  offsets = header_info["data_offsets"]
  itemsize = np.dtype(stored_dtype).itemsize

  # This host's half-open element range of the flat tensor; the last host
  # may get a short (possibly empty) range.
  element_count = int(np.prod(stored_shape)) if stored_shape else 1
  per_host = (element_count + num_hosts - 1) // num_hosts
  first = host_id * per_host
  last = min((host_id + 1) * per_host, element_count)
  read_count = max(0, last - first)

  file_pos = offsets[0] + data_start_offset + first * itemsize
  async with async_path.open_file(path, mode="rb") as f:
    await f.seek(file_pos)
    chunk = await f.read(read_count * itemsize)

  return _create_sharded_array(
      chunk,
      abstract_leaf,
      stored_shape,
      stored_dtype,
      num_hosts,
      host_id,
      flat_sharding,
  )


async def _load_safetensors_on_device(
    path: Path, abstract_pytree: dict[str, Any]
) -> dict[str, jax.Array]:
  """Loads tensors from a safetensors file into on-device JAX arrays.

  NOTE(review): the visible span interleaves removed and added diff lines;
  this body is the coherent post-change implementation assembled from the
  added lines only (concurrent per-tensor loads via `asyncio.gather`).

  Args:
    path: Path of the safetensors file.
    abstract_pytree: Flat mapping of tensor name to abstract array; each
      leaf supplies the target dtype and (optionally) sharding.

  Returns:
    Mapping of tensor name to restored JAX array.

  Raises:
    KeyError: If a requested tensor is absent from the file's header.
    ValueError: If hosts expose differing numbers of devices.
  """
  header, data_start_offset = await _read_safetensors_header(path)
  restored_pytree = {}
  num_hosts = multihost.process_count()
  host_id = jax.process_index()

  # Build an initial mesh grouping all global devices by host.
  devices_by_host = []
  for i in range(num_hosts):
    devices_by_host.append([
        d
        for d in jax.devices()
        if multihost_v0.process_index_from_device(d) == i
    ])

  # Ensure uniform mesh shape (in case of uneven device counts, which is rare)
  num_devices_per_host = len(devices_by_host[0])
  for d in devices_by_host:
    if len(d) != num_devices_per_host:
      raise ValueError("Number of devices must be the same across all hosts.")

  initial_mesh = jax.sharding.Mesh(
      np.array(devices_by_host), ("hosts", "devices")
  )
  # 1-D sharding that assigns one equal chunk of a flat tensor per host.
  flat_sharding = jax.sharding.NamedSharding(
      initial_mesh, jax.sharding.PartitionSpec("hosts")
  )

  async def _load_tensor(
      tensor_name: str, abstract_leaf: Any
  ) -> tuple[str, jax.Array]:
    # Dispatch on whether the leaf requests a sharded placement.
    if abstract_leaf.sharding is None:
      tensor = await _load_non_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
      )
    else:
      # We have a target sharding.
      tensor = await _load_sharded_array(
          path,
          abstract_leaf,
          header[tensor_name],
          data_start_offset,
          num_hosts,
          host_id,
          flat_sharding,
      )
    return tensor_name, tensor

  # Validate names up front, then load all tensors concurrently.
  tasks = []
  for tensor_name, abstract_leaf in abstract_pytree.items():
    if tensor_name not in header:
      raise KeyError(
          f"Tensor '{tensor_name}' not found in safetensors header of {path}."
      )
    tasks.append(_load_tensor(tensor_name, abstract_leaf))

  results = await asyncio.gather(*tasks)
  for tensor_name, tensor in results:
    restored_pytree[tensor_name] = tensor

  return restored_pytree


Expand Down
Loading