Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- #v1 Add `use_load_and_broadcast` option.
- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.
- #v1 #safetensors Implement `load_pytree_async`; currently supported only
  for the Safetensors format.
- #v1 Add `DeletionOptions` to configure V1 Checkpointer's checkpoint deletion
behavior.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.logging import event_tracking
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import utils as path_utils



Expand All @@ -57,11 +57,6 @@ def _on_commit_callback(
checkpoint_start_time=checkpoint_start_time,
)
)
total_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/total_duration_secs',
total_duration_secs,
)


def _background_wait_for_commit_futures(
Expand Down Expand Up @@ -437,11 +432,14 @@ def _callback() -> None:
tmpdir,
checkpoint_start_time,
)
logging.info(
'Finished async_save (blocking + background). Time taken: %fs.'
' directory=%s',
time.time() - checkpoint_start_time,
operation_recorder = event_tracking.OperationRecorder(
tmpdir.get_final(),
operation_type=event_tracking.OperationType.SAVE,
async_origin=True,
primary_host=self._primary_host,
)
operation_recorder.record_completion(
time.time() - checkpoint_start_time
)
# Clean up all awaitable signals for the current operation id as they are
# no longer needed.
Expand All @@ -460,21 +458,6 @@ async def _save(
**kwargs,
):
directory = tmpdir.get_final()

if utils.is_primary_host(self._primary_host):
jax.monitoring.record_event(
'/jax/orbax/write/storage_type',
storage_type=path_utils.get_storage_type(directory),
)
# TODO(dicentra): Revise other metrics to also only report from the primary
# host where appropriate.
jax.monitoring.record_event('/jax/orbax/write/async/start')
logging.info(
'[process=%s] Started async saving checkpoint to %s.',
multihost.process_index(),
directory,
)

if await async_path.exists(directory):
if force:
if utils.is_primary_host(self._primary_host):
Expand Down Expand Up @@ -561,6 +544,13 @@ def save(
),
)
directory = epath.Path(directory)
operation_recorder = event_tracking.OperationRecorder(
directory,
operation_type=event_tracking.OperationType.SAVE,
async_origin=True,
primary_host=self._primary_host,
)
operation_recorder.record_start()
tmpdir = self.get_temporary_path(directory)
self.wait_until_finished()
self.synchronize_next_awaitable_signal_operation_id()
Expand All @@ -575,22 +565,14 @@ def save(
**kwargs,
)
)
operation_recorder.record_blocking_completion(
time.time() - checkpoint_start_time
)
self._async_manager.start_async_commit(
directory,
commit_futures=commit_ops,
on_commit_callback=on_commit_callback,
)
blocking_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/blocking_duration_secs',
blocking_duration_secs,
)
logging.info(
'Finished blocking save. Time taken: %fs. Continuing background save'
' to %s.',
blocking_duration_secs,
directory,
)

def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any:
"""See superclass documentation."""
Expand Down
56 changes: 18 additions & 38 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import gcs_utils
from orbax.checkpoint._src.path import step as step_lib
from orbax.checkpoint._src.path import utils as path_utils
from typing_extensions import Self # for Python version < 3.11


Expand Down Expand Up @@ -230,20 +229,13 @@ def save(
),
)
directory = epath.Path(directory)

if multihost.is_primary_host(self._primary_host):
jax.monitoring.record_event(
'/jax/orbax/write/storage_type',
storage_type=path_utils.get_storage_type(directory),
)
# TODO(dicentra): Revise other metrics to also only report from the primary
# host where appropriate.
jax.monitoring.record_event('/jax/orbax/write/start')
logging.info(
'[process=%s] Started saving checkpoint to %s.',
multihost.process_index(),
operation_recorder = event_tracking.OperationRecorder(
directory,
operation_type=event_tracking.OperationType.SAVE,
async_origin=False,
primary_host=self._primary_host,
)
operation_recorder.record_start()
self.synchronize_next_awaitable_signal_operation_id()

if directory.exists():
Expand Down Expand Up @@ -273,6 +265,9 @@ def save(
),
processes=self._active_processes,
)
operation_recorder.record_blocking_completion(
time.time() - checkpoint_start_time
)

# Ensure save operation atomicity and record time saved by checkpoint.
if multihost.is_primary_host(self._primary_host):
Expand All @@ -292,50 +287,35 @@ def save(
),
processes=self._active_processes,
)
save_duration_secs = time.time() - checkpoint_start_time
logging.info(
'Finished synchronous save in %.2f seconds to %s',
save_duration_secs,
directory,
)
operation_recorder.record_completion(time.time() - checkpoint_start_time)

def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any:
"""See superclass documentation."""
restore_start_time = time.time()
directory = epath.Path(directory)
operation_recorder = event_tracking.OperationRecorder(
directory,
operation_type=event_tracking.OperationType.LOAD,
async_origin=False,
primary_host=self._primary_host,
)
operation_recorder.record_start()
if not directory.exists():
raise FileNotFoundError(f'Checkpoint at {directory} not found.')
if not step_lib.is_path_finalized(directory):
raise ValueError(f'Found incomplete checkpoint at {directory}.')
logging.info('Restoring checkpoint from %s.', directory)
ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
restored = self._restore(directory, args=ckpt_args)

event_tracking.record_read_event(directory)

operation_recorder.record_blocking_completion(0.)
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:restore',
prefix=self._barrier_sync_key_prefix,
),
processes=self._active_processes,
)
restore_duration_secs = time.time() - restore_start_time
logging.info(
'Finished restoring checkpoint in %.2f seconds from %s.',
restore_duration_secs,
directory,
)

if multihost.is_primary_host(self._primary_host):
jax.monitoring.record_event(
'/jax/orbax/read/storage_type',
storage_type=path_utils.get_storage_type(directory),
)
jax.monitoring.record_event_duration_secs(
'/jax/orbax/read/total_duration_secs',
restore_duration_secs,
)
operation_recorder.record_completion(time.time() - restore_start_time)
return restored

def _restore(
Expand Down
160 changes: 122 additions & 38 deletions checkpoint/orbax/checkpoint/_src/logging/event_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Logging utilities for tracking checkpoint events."""

import enum

from absl import logging
from etils import epath
import jax
Expand Down Expand Up @@ -42,42 +44,124 @@ def record_delete_event(directory: epath.Path):
return None


def record_save_start(path: epath.Path, *, async_origin: bool):
"""Records the start of a save operation."""
logging.info(
'[process=%s] Started %s checkpoint to %s.',
multihost.process_index(),
'async saving' if async_origin else 'saving',
path,
)
if async_origin:
event_name = '/jax/orbax/write/async/start'
else:
event_name = '/jax/orbax/write/start'
jax.monitoring.record_event(event_name)
jax.monitoring.record_event(
'/jax/orbax/write/storage_type',
storage_type=path_utils.get_storage_type(path),
)


def record_save_completion(
path: epath.Path,
*,
total_duration_secs: float,
async_origin: bool,
):
"""Records the completion of a save operation."""
logging.info(
'Finished asynchronous save (blocking + background) in %.2f seconds'
' to %s',
total_duration_secs,
path,
)
# TODO(cpgaffney): No event is currently being recorded for synchronous saves.
# Consider collecting this information
if async_origin:
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/total_duration_secs',
total_duration_secs,
class OperationType(enum.Enum):
  """The type of checkpoint operation being recorded (save or load)."""
  SAVE = 'save'
  LOAD = 'load'


class OperationRecorder:
"""Records durations and events for checkpointing (save/load) operations."""

def __init__(
self,
path: epath.Path,
operation_type: OperationType,
*,
async_origin: bool,
primary_host: int = 0,
):
self._path = path
self._operation_type = operation_type
self._async_origin = async_origin
self._primary_host = primary_host

def record_start(self):
"""Records the start of an operation."""
logging.info(
'[process=%s] [%s] Started %s checkpoint @ %s.',
multihost.process_index(),
'async' if self._async_origin else 'sync',
self._operation_type.value,
self._path,
)

match self._operation_type:
case OperationType.SAVE:
event_name = (
'/jax/orbax/write/async/start'
if self._async_origin
else '/jax/orbax/write/start'
)
case OperationType.LOAD:
event_name = (
'/jax/orbax/read/async/start'
if self._async_origin
else '/jax/orbax/read/start'
)

if multihost.is_primary_host(self._primary_host):
jax.monitoring.record_event(event_name)

if self._operation_type == OperationType.SAVE:
jax.monitoring.record_event(
'/jax/orbax/write/storage_type',
storage_type=path_utils.get_storage_type(self._path),
)

def record_blocking_completion(self, duration_secs: float):
"""Records the completion of the blocking part of an operation."""
match self._operation_type:
case OperationType.SAVE:
event_name = (
'/jax/checkpoint/write/async/blocking_duration_secs'
if self._async_origin
else '/jax/orbax/write/blocking_duration_secs'
)
record_write_event(self._path)
case OperationType.LOAD:
event_name = (
'/jax/orbax/read/async/blocking_duration_secs'
if self._async_origin
else '/jax/orbax/read/blocking_duration_secs'
)
record_read_event(self._path)

if multihost.is_primary_host(self._primary_host):
jax.monitoring.record_event_duration_secs(
event_name,
duration_secs,
)

logging.info(
'[process=%s] [%s] Finished blocking %s in %.2f seconds. Continuing %s'
' @ %s.',
multihost.process_index(),
'async' if self._async_origin else 'sync',
self._operation_type.value,
duration_secs,
self._operation_type.value,
self._path,
)

def record_completion(self, duration_secs: float):
"""Records the completion of an entire operation."""
logging.info(
'[process=%s] [%s] Finished %s%s in %.2f seconds @ %s',
multihost.process_index(),
'async' if self._async_origin else 'sync',
self._operation_type.value,
' (blocking + background)' if self._async_origin else '',
duration_secs,
self._path,
)
match self._operation_type:
case OperationType.SAVE:
duration_event_name = (
'/jax/checkpoint/write/async/total_duration_secs'
if self._async_origin
else '/jax/orbax/write/total_duration_secs'
)
success_event_name = '/jax/orbax/write/success'
case OperationType.LOAD:
duration_event_name = (
'/jax/orbax/read/async/total_duration_secs'
if self._async_origin
else '/jax/orbax/read/total_duration_secs'
)
success_event_name = '/jax/orbax/read/success'
if multihost.is_primary_host(self._primary_host):
jax.monitoring.record_event(success_event_name)
jax.monitoring.record_event_duration_secs(
duration_event_name,
duration_secs,
)
Loading
Loading