diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index 5515945aa..17b19f054 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -14,18 +14,20 @@ """Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats.""" +import asyncio import collections import json -from typing import Any, Awaitable, Sequence, cast +from typing import Any, Awaitable, Sequence import jax import jax.numpy as jnp import numpy as np -from orbax.checkpoint._src.arrays import numpy_utils +from orbax.checkpoint._src.multihost import multihost as multihost_v0 from orbax.checkpoint._src.path import async_path from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types +from orbax.checkpoint.experimental.v1._src.synchronization import multihost CheckpointLayout = checkpoint_layout.CheckpointLayout InvalidLayoutError = checkpoint_layout.InvalidLayoutError @@ -33,6 +35,9 @@ HEADER_NUM_BYTES = 8 SAFETENSORS_SUFFIX = ".safetensors" +MAX_GAP_SIZE_BYTES = ( + 32 * 1024 * 1024 +) # 32 MB gap allowed between tensors in a coalesced read block def _get_dtypes() -> dict[str, Any]: @@ -92,75 +97,6 @@ def _get_array_properties(info: dict[str, Any]) -> tuple[tuple[int, ...], Any]: return shape, dtype -async def _read_non_contiguous_slice( - f: async_path.AsyncFile, - idx: tuple[slice, ...], - stored_shape: tuple[int, ...], - stored_dtype: np.dtype, - tensor_file_offset: int, -) -> np.ndarray: - """Reads a slice of a tensor from a file. - - This function solves the problem of reading a multi-dimensional slice from an - array where the slice's data is not stored as a single, contiguous block in - the file. It does so by recursively "walking" the dimensions of the slice. - - Args: - f: The asynchronous file object (binary read mode) - idx: A tuple of slice objects representing the n-dimensional slice to - read. - stored_shape: The shape of the tensor. - stored_dtype: The `dtype` of the tensor. - tensor_file_offset: The starting byte offset of the tensor's data within - the file. - - Returns: - The specific tensor slice. - """ - # Handle 0-d scalar case - if not idx: - await f.seek(tensor_file_offset) - num_bytes = np.dtype(stored_dtype).itemsize - scalar_bytes = await f.read(num_bytes) - # Reshape to () to create a 0-D NumPy array. - return np.frombuffer(scalar_bytes, dtype=stored_dtype).reshape(()) - - itemsize = np.dtype(stored_dtype).itemsize - - # Calculate the byte strides for the full tensor. The stride for a - # dimension is the number of bytes to "jump" to get to the next element - # in that dimension while keeping all other indices the same. - global_strides = [itemsize] * len(stored_shape) - for i in range(len(stored_shape) - 2, -1, -1): - global_strides[i] = global_strides[i + 1] * stored_shape[i + 1] - - async def _read_slice_recursively(dim: int, base_offset: int) -> bytes: - # TODO(b/438763866) - @zachmeyers to consider alternative methods. - s = idx[dim] # The slice for the current dimension. - - # If we are at the last dimension, the data is contiguous. - if dim == len(stored_shape) - 1: - start = base_offset + s.start * global_strides[dim] - num_bytes = (s.stop - s.start) * itemsize - await f.seek(tensor_file_offset + start) - return cast(bytes, await f.read(num_bytes)) - - # For all other dimensions, iterate through the indices - # of the slice and make a recursive call for the next dimension. - chunks = [] - for i in range(s.start, s.stop): - offset = base_offset + i * global_strides[dim] - chunk = await _read_slice_recursively(dim + 1, offset) - chunks.append(chunk) - - return b"".join(chunks) - - # Start the recursive reading process from the first dimension. - slice_bytes = await _read_slice_recursively(dim=0, base_offset=0) - shard_shape = numpy_utils.slice_shape(idx) - return np.frombuffer(slice_bytes, dtype=stored_dtype).reshape(shard_shape) - - async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]: """Loads tensors from a safetensors file into host NumPy arrays.""" header, data_start_offset = await _read_safetensors_header(path) @@ -179,65 +115,201 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]: return tensors +def _create_non_sharded_array( + raw_data: memoryview | bytes, + abstract_leaf: Any, + stored_shape: tuple[int, ...], + stored_dtype: Any, +) -> jax.Array: + """Creates a non-sharded JAX array from raw bytes.""" + np_array = np.frombuffer(raw_data, dtype=stored_dtype).reshape(stored_shape) + target_dtype = abstract_leaf.dtype + if np_array.dtype != target_dtype: + np_array = np_array.astype(target_dtype) + return jax.device_put(np_array) + + +def _create_sharded_array( + raw_data: memoryview | bytes, + abstract_leaf: Any, + stored_shape: tuple[int, ...], + stored_dtype: Any, + num_hosts: int, + host_id: int, + flat_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """Creates a sharded JAX array from raw bytes.""" + sharding = abstract_leaf.sharding + target_dtype = abstract_leaf.dtype + + # Use 1D flat contiguous read + reshard logic for maximum IO throughput. + total_elements = int(np.prod(stored_shape)) if stored_shape else 1 + + # Calculate padding + elements_per_host = (total_elements + num_hosts - 1) // num_hosts + padded_elements = elements_per_host * num_hosts + + start_idx = host_id * elements_per_host + end_idx = min((host_id + 1) * elements_per_host, total_elements) + num_elements_to_read = max(0, end_idx - start_idx) + + local_data = np.frombuffer(raw_data, dtype=stored_dtype) + if local_data.dtype != target_dtype: + local_data = local_data.astype(target_dtype) + + if num_elements_to_read < elements_per_host: + local_data = np.pad( + local_data, (0, elements_per_host - num_elements_to_read) + ) + + # Put local data on all addressable devices in the flat sharding + local_arrays = [ + jax.device_put(local_data, d) for d in flat_sharding.addressable_devices + ] + + # Create the 1D sharded array + flat_array = jax.make_array_from_single_device_arrays( + (padded_elements,), flat_sharding, local_arrays + ) + + # Slice off the padding and reshape + if padded_elements > total_elements: + flat_array = flat_array[:total_elements] + + reshaped_array = flat_array.reshape(stored_shape) + + # Reshard to the target sharding + target_array = jax.device_put(reshaped_array, sharding) + + return target_array + + +async def _load_non_sharded_array( + path: Path, + abstract_leaf: Any, + header_info: dict[str, Any], + data_start_offset: int, +) -> jax.Array: + """Loads a single non-sharded array from a safetensors file.""" + stored_shape, stored_dtype = _get_array_properties(header_info) + st_data_offsets = header_info["data_offsets"] + + start_offset, end_offset = st_data_offsets + num_bytes = end_offset - start_offset + async with async_path.open_file(path, mode="rb") as f: + await f.seek(data_start_offset + start_offset) + tensor_bytes = await f.read(num_bytes) + + return _create_non_sharded_array( + tensor_bytes, abstract_leaf, stored_shape, stored_dtype + ) + + +async def _load_sharded_array( + path: Path, + abstract_leaf: Any, + header_info: dict[str, Any], + data_start_offset: int, + num_hosts: int, + host_id: int, + flat_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """Loads a single sharded array from a safetensors file.""" + stored_shape, stored_dtype = _get_array_properties(header_info) + st_data_offsets = header_info["data_offsets"] + + total_elements = int(np.prod(stored_shape)) if stored_shape else 1 + elements_per_host = (total_elements + num_hosts - 1) // num_hosts + start_idx = host_id * elements_per_host + end_idx = min((host_id + 1) * elements_per_host, total_elements) + num_elements_to_read = max(0, end_idx - start_idx) + itemsize = np.dtype(stored_dtype).itemsize + + start_byte = st_data_offsets[0] + data_start_offset + start_idx * itemsize + num_bytes = num_elements_to_read * itemsize + + async with async_path.open_file(path, mode="rb") as f: + await f.seek(start_byte) + raw_data = await f.read(num_bytes) + + return _create_sharded_array( + raw_data, + abstract_leaf, + stored_shape, + stored_dtype, + num_hosts, + host_id, + flat_sharding, + ) + + async def _load_safetensors_on_device( path: Path, abstract_pytree: dict[str, Any] ) -> dict[str, jax.Array]: """Loads tensors from a safetensors file into on-device JAX arrays.""" header, data_start_offset = await _read_safetensors_header(path) restored_pytree = {} - async with async_path.open_file(path, mode="rb") as f: - for tensor_name, abstract_leaf in abstract_pytree.items(): - if tensor_name not in header: - raise KeyError( - f"Tensor '{tensor_name}' not found in safetensors header of {path}." - ) - stored_shape, stored_dtype = _get_array_properties(header[tensor_name]) - st_data_offsets = header[tensor_name]["data_offsets"] - sharding = abstract_leaf.sharding - target_shape = abstract_leaf.shape - target_dtype = abstract_leaf.dtype - - if sharding is None: - start_offset, end_offset = st_data_offsets - num_bytes = end_offset - start_offset - await f.seek(data_start_offset + start_offset) - tensor_bytes = await f.read(num_bytes) - np_array = np.frombuffer(tensor_bytes, dtype=stored_dtype).reshape( - stored_shape - ) - if np_array.dtype != target_dtype: - np_array = np_array.astype(target_dtype) - restored_pytree[tensor_name] = jax.device_put(np_array) - continue - - device_indices_map = sharding.addressable_devices_indices_map( - target_shape + num_hosts = multihost.process_count() + host_id = jax.process_index() + + # Build an initial mesh grouping all global devices by host + devices_by_host = [] + for i in range(num_hosts): + devices_by_host.append([ + d + for d in jax.devices() + if multihost_v0.process_index_from_device(d) == i + ]) + + # Ensure uniform mesh shape (in case of uneven device counts, which is rare) + num_devices_per_host = len(devices_by_host[0]) + for d in devices_by_host: + if len(d) != num_devices_per_host: + raise ValueError("Number of devices must be the same across all hosts.") + + initial_mesh = jax.sharding.Mesh( + np.array(devices_by_host), ("hosts", "devices") + ) + flat_sharding = jax.sharding.NamedSharding( + initial_mesh, jax.sharding.PartitionSpec("hosts") + ) + + async def _load_tensor( + tensor_name: str, abstract_leaf: Any + ) -> tuple[str, jax.Array]: + if abstract_leaf.sharding is None: + tensor = await _load_non_sharded_array( + path, + abstract_leaf, + header[tensor_name], + data_start_offset, ) + else: + # We have a target sharding. + tensor = await _load_sharded_array( + path, + abstract_leaf, + header[tensor_name], + data_start_offset, + num_hosts, + host_id, + flat_sharding, + ) + return tensor_name, tensor - device_map = [] - for device in device_indices_map: - idx = device_indices_map[device] - resolved_idx = numpy_utils.resolve_slice(idx, stored_shape) - shard_shape = numpy_utils.slice_shape(resolved_idx) - - shard_np = await _read_non_contiguous_slice( - f, - resolved_idx, - stored_shape, - stored_dtype, - st_data_offsets[0] + data_start_offset, - ) - shard_np = shard_np.reshape(shard_shape) # pytype: disable=attribute-error - - if shard_np.dtype != target_dtype: - shard_np = shard_np.astype(target_dtype) + tasks = [] + for tensor_name, abstract_leaf in abstract_pytree.items(): + if tensor_name not in header: + raise KeyError( + f"Tensor '{tensor_name}' not found in safetensors header of {path}." + ) + tasks.append(_load_tensor(tensor_name, abstract_leaf)) - device_map.append(jax.device_put(shard_np, device)) + results = await asyncio.gather(*tasks) + for tensor_name, tensor in results: + restored_pytree[tensor_name] = tensor - restored_pytree[tensor_name] = jax.make_array_from_single_device_arrays( - target_shape, sharding, device_map - ) return restored_pytree