From d6df8fb9f59a787d5a91decf176fdf86e6213578 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Thu, 2 Apr 2026 15:46:40 -0700 Subject: [PATCH] #v1 #safetensors Implement `load_pytree_async`, allowed only for the Safetensors format at the moment. PiperOrigin-RevId: 893742303 --- checkpoint/CHANGELOG.md | 2 + .../_src/checkpointers/async_checkpointer.py | 54 ++---- .../_src/checkpointers/checkpointer.py | 56 ++---- .../checkpoint/_src/logging/event_tracking.py | 160 +++++++++++++----- .../orbax/checkpoint/_src/path/atomicity.py | 1 - .../v1/_src/loading/layout_loading_test.py | 43 +++++ .../experimental/v1/_src/loading/loading.py | 155 +++++++++++++++-- .../experimental/v1/_src/saving/execution.py | 27 ++- 8 files changed, 359 insertions(+), 139 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 53a49f7c1..2a9cf743b 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - #v1 Add `use_load_and_broadcast` option. - Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite. +- #v1 #safetensors Implement `load_pytree_async`, allowed only for the +Safetensors format at the moment. - #v1 Add `DeletionOptions` to configure V1 Checkpointer's checkpoint deletion behavior. 
diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 05e3b344a..07d546d5d 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -31,12 +31,12 @@ from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.futures import synchronization from orbax.checkpoint._src.handlers import async_checkpoint_handler +from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import atomicity from orbax.checkpoint._src.path import atomicity_types -from orbax.checkpoint._src.path import utils as path_utils @@ -57,11 +57,6 @@ def _on_commit_callback( checkpoint_start_time=checkpoint_start_time, ) ) - total_duration_secs = time.time() - checkpoint_start_time - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/total_duration_secs', - total_duration_secs, - ) def _background_wait_for_commit_futures( @@ -437,11 +432,14 @@ def _callback() -> None: tmpdir, checkpoint_start_time, ) - logging.info( - 'Finished async_save (blocking + background). Time taken: %fs.' - ' directory=%s', - time.time() - checkpoint_start_time, + operation_recorder = event_tracking.OperationRecorder( tmpdir.get_final(), + operation_type=event_tracking.OperationType.SAVE, + async_origin=True, + primary_host=self._primary_host, + ) + operation_recorder.record_completion( + time.time() - checkpoint_start_time ) # Clean up all awaitable signals for the current operation id as they are # no longer needed. 
@@ -460,21 +458,6 @@ async def _save( **kwargs, ): directory = tmpdir.get_final() - - if utils.is_primary_host(self._primary_host): - jax.monitoring.record_event( - '/jax/orbax/write/storage_type', - storage_type=path_utils.get_storage_type(directory), - ) - # TODO(dicentra): Revise other metrics to also only report from the primary - # host where appropriate. - jax.monitoring.record_event('/jax/orbax/write/async/start') - logging.info( - '[process=%s] Started async saving checkpoint to %s.', - multihost.process_index(), - directory, - ) - if await async_path.exists(directory): if force: if utils.is_primary_host(self._primary_host): @@ -561,6 +544,13 @@ def save( ), ) directory = epath.Path(directory) + operation_recorder = event_tracking.OperationRecorder( + directory, + operation_type=event_tracking.OperationType.SAVE, + async_origin=True, + primary_host=self._primary_host, + ) + operation_recorder.record_start() tmpdir = self.get_temporary_path(directory) self.wait_until_finished() self.synchronize_next_awaitable_signal_operation_id() @@ -575,22 +565,14 @@ def save( **kwargs, ) ) + operation_recorder.record_blocking_completion( + time.time() - checkpoint_start_time + ) self._async_manager.start_async_commit( directory, commit_futures=commit_ops, on_commit_callback=on_commit_callback, ) - blocking_duration_secs = time.time() - checkpoint_start_time - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/blocking_duration_secs', - blocking_duration_secs, - ) - logging.info( - 'Finished blocking save. Time taken: %fs. 
Continuing background save' - ' to %s.', - blocking_duration_secs, - directory, - ) def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: """See superclass documentation.""" diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py index 7d91bdda1..ea2147c1a 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py @@ -37,7 +37,6 @@ from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.path import gcs_utils from orbax.checkpoint._src.path import step as step_lib -from orbax.checkpoint._src.path import utils as path_utils from typing_extensions import Self # for Python version < 3.11 @@ -230,20 +229,13 @@ def save( ), ) directory = epath.Path(directory) - - if multihost.is_primary_host(self._primary_host): - jax.monitoring.record_event( - '/jax/orbax/write/storage_type', - storage_type=path_utils.get_storage_type(directory), - ) - # TODO(dicentra): Revise other metrics to also only report from the primary - # host where appropriate. - jax.monitoring.record_event('/jax/orbax/write/start') - logging.info( - '[process=%s] Started saving checkpoint to %s.', - multihost.process_index(), + operation_recorder = event_tracking.OperationRecorder( directory, + operation_type=event_tracking.OperationType.SAVE, + async_origin=False, + primary_host=self._primary_host, ) + operation_recorder.record_start() self.synchronize_next_awaitable_signal_operation_id() if directory.exists(): @@ -273,6 +265,9 @@ def save( ), processes=self._active_processes, ) + operation_recorder.record_blocking_completion( + time.time() - checkpoint_start_time + ) # Ensure save operation atomicity and record time saved by checkpoint. 
if multihost.is_primary_host(self._primary_host): @@ -292,17 +287,19 @@ def save( ), processes=self._active_processes, ) - save_duration_secs = time.time() - checkpoint_start_time - logging.info( - 'Finished synchronous save in %.2f seconds to %s', - save_duration_secs, - directory, - ) + operation_recorder.record_completion(time.time() - checkpoint_start_time) def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: """See superclass documentation.""" restore_start_time = time.time() directory = epath.Path(directory) + operation_recorder = event_tracking.OperationRecorder( + directory, + operation_type=event_tracking.OperationType.LOAD, + async_origin=False, + primary_host=self._primary_host, + ) + operation_recorder.record_start() if not directory.exists(): raise FileNotFoundError(f'Checkpoint at {directory} not found.') if not step_lib.is_path_finalized(directory): @@ -310,9 +307,7 @@ def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: logging.info('Restoring checkpoint from %s.', directory) ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs) restored = self._restore(directory, args=ckpt_args) - - event_tracking.record_read_event(directory) - + operation_recorder.record_blocking_completion(0.) 
multihost.sync_global_processes( multihost.unique_barrier_key( 'Checkpointer:restore', @@ -320,22 +315,7 @@ def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: ), processes=self._active_processes, ) - restore_duration_secs = time.time() - restore_start_time - logging.info( - 'Finished restoring checkpoint in %.2f seconds from %s.', - restore_duration_secs, - directory, - ) - - if multihost.is_primary_host(self._primary_host): - jax.monitoring.record_event( - '/jax/orbax/read/storage_type', - storage_type=path_utils.get_storage_type(directory), - ) - jax.monitoring.record_event_duration_secs( - '/jax/orbax/read/total_duration_secs', - restore_duration_secs, - ) + operation_recorder.record_completion(time.time() - restore_start_time) return restored def _restore( diff --git a/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py b/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py index dccb09c06..15ad5cb52 100644 --- a/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py +++ b/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py @@ -14,6 +14,8 @@ """Logging utilities for tracking checkpoint events.""" +import enum + from absl import logging from etils import epath import jax @@ -42,42 +44,124 @@ def record_delete_event(directory: epath.Path): return None -def record_save_start(path: epath.Path, *, async_origin: bool): - """Records the start of a save operation.""" - logging.info( - '[process=%s] Started %s checkpoint to %s.', - multihost.process_index(), - 'async saving' if async_origin else 'saving', - path, - ) - if async_origin: - event_name = '/jax/orbax/write/async/start' - else: - event_name = '/jax/orbax/write/start' - jax.monitoring.record_event(event_name) - jax.monitoring.record_event( - '/jax/orbax/write/storage_type', - storage_type=path_utils.get_storage_type(path), - ) - - -def record_save_completion( - path: epath.Path, - *, - total_duration_secs: float, - async_origin: bool, -): - """Records the completion 
of a save operation.""" - logging.info( - 'Finished asynchronous save (blocking + background) in %.2f seconds' - ' to %s', - total_duration_secs, - path, - ) - # TODO(cpgaffney): No event is currently being recorded for synchronous saves. - # Consider collecting this information - if async_origin: - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/total_duration_secs', - total_duration_secs, +class OperationType(enum.Enum): + SAVE = 'save' + LOAD = 'load' + + +class OperationRecorder: + """Records durations and events for checkpointing (save/load) operations.""" + + def __init__( + self, + path: epath.Path, + operation_type: OperationType, + *, + async_origin: bool, + primary_host: int = 0, + ): + self._path = path + self._operation_type = operation_type + self._async_origin = async_origin + self._primary_host = primary_host + + def record_start(self): + """Records the start of an operation.""" + logging.info( + '[process=%s] [%s] Started %s checkpoint @ %s.', + multihost.process_index(), + 'async' if self._async_origin else 'sync', + self._operation_type.value, + self._path, + ) + + match self._operation_type: + case OperationType.SAVE: + event_name = ( + '/jax/orbax/write/async/start' + if self._async_origin + else '/jax/orbax/write/start' + ) + case OperationType.LOAD: + event_name = ( + '/jax/orbax/read/async/start' + if self._async_origin + else '/jax/orbax/read/start' + ) + + if multihost.is_primary_host(self._primary_host): + jax.monitoring.record_event(event_name) + + if self._operation_type == OperationType.SAVE: + jax.monitoring.record_event( + '/jax/orbax/write/storage_type', + storage_type=path_utils.get_storage_type(self._path), + ) + + def record_blocking_completion(self, duration_secs: float): + """Records the completion of the blocking part of an operation.""" + match self._operation_type: + case OperationType.SAVE: + event_name = ( + '/jax/checkpoint/write/async/blocking_duration_secs' + if self._async_origin + else 
'/jax/orbax/write/blocking_duration_secs' + ) + record_write_event(self._path) + case OperationType.LOAD: + event_name = ( + '/jax/orbax/read/async/blocking_duration_secs' + if self._async_origin + else '/jax/orbax/read/blocking_duration_secs' + ) + record_read_event(self._path) + + if multihost.is_primary_host(self._primary_host): + jax.monitoring.record_event_duration_secs( + event_name, + duration_secs, + ) + + logging.info( + '[process=%s] [%s] Finished blocking %s in %.2f seconds. Continuing %s' + ' @ %s.', + multihost.process_index(), + 'async' if self._async_origin else 'sync', + self._operation_type.value, + duration_secs, + self._operation_type.value, + self._path, + ) + + def record_completion(self, duration_secs: float): + """Records the completion of an entire operation.""" + logging.info( + '[process=%s] [%s] Finished %s%s in %.2f seconds @ %s', + multihost.process_index(), + 'async' if self._async_origin else 'sync', + self._operation_type.value, + ' (blocking + background)' if self._async_origin else '', + duration_secs, + self._path, ) + match self._operation_type: + case OperationType.SAVE: + duration_event_name = ( + '/jax/checkpoint/write/async/total_duration_secs' + if self._async_origin + else '/jax/orbax/write/total_duration_secs' + ) + success_event_name = '/jax/orbax/write/success' + case OperationType.LOAD: + duration_event_name = ( + '/jax/orbax/read/async/total_duration_secs' + if self._async_origin + else '/jax/orbax/read/total_duration_secs' + ) + success_event_name = '/jax/orbax/read/success' + if multihost.is_primary_host(self._primary_host): + jax.monitoring.record_event(success_event_name) + jax.monitoring.record_event_duration_secs( + duration_event_name, + duration_secs, + ) diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py index 7e3c8c061..d50b10eae 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity.py +++ 
b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -844,7 +844,6 @@ async def on_commit_callback( await tmp_dir.finalize( ) record_saved_duration(checkpoint_start_time) - jax.monitoring.record_event('/jax/orbax/write/success') logging.info( '[process=%s][thread=%s] Finished saving checkpoint (finalized tmp dir)' ' to `%s`.', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py index 1b9f81c46..a09426c10 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import time +from unittest import mock + from absl.testing import absltest from absl.testing import parameterized from etils import epath @@ -141,6 +145,45 @@ def test_load_checkpointables_with_checkpoint_metadata(self): ) test_utils.assert_tree_equal(self, self.checkpointables_to_save, loaded) + @parameterized.parameters( + (options_lib.CheckpointLayout.SAFETENSORS,), + (options_lib.CheckpointLayout.ORBAX,), + ) + def test_load_pytree_async(self, layout: options_lib.CheckpointLayout): + original_finalize_load = loading._LoadPyTreeResponse._finalize_load + + async def sleep_and_load(*args, **kwargs): + await asyncio.sleep(2) + return await original_finalize_load(*args, **kwargs) + + self.enter_context( + mock.patch.object( + loading._LoadPyTreeResponse, + '_finalize_load', + new=sleep_and_load, + ) + ) + + pytree = self.object_to_save + if layout == options_lib.CheckpointLayout.SAFETENSORS: + directory = self.safetensors_path + else: + directory = self.orbax_pytree_path + + with context_lib.Context(checkpoint_layout=layout): + if layout != options_lib.CheckpointLayout.SAFETENSORS: + with self.assertRaises(NotImplementedError): + 
loading.load_pytree_async(directory) + return + + start = time.time() + response = loading.load_pytree_async(directory) + + self.assertLess(time.time() - start, 1) + loaded = response.result() + self.assertGreater(time.time() - start, 2) + test_utils.assert_tree_equal(self, pytree, loaded) + # TODO(b/431045454): Add tests for abstract_checkpointables. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py index 972bbfc0e..75df2d660 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/loading.py @@ -13,6 +13,7 @@ # limitations under the License. """Defines free-function interface for loading.""" +from __future__ import annotations import functools import time @@ -22,6 +23,7 @@ from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.logging import event_tracking from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib import orbax.checkpoint.experimental.v1._src.handlers.global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout from orbax.checkpoint.experimental.v1._src.layout import registry as layout_registry @@ -29,6 +31,7 @@ from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.synchronization import multihost +from orbax.checkpoint.experimental.v1._src.synchronization import thread_utils from orbax.checkpoint.experimental.v1._src.synchronization import types as async_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -38,6 +41,8 @@ CheckpointMetadata = metadata_types.CheckpointMetadata PLACEHOLDER = ... 
+AsyncResponse = async_types.AsyncResponse + class LoadFn(Protocol): """Protocol for a two-phase load function used in `_load_impl`. @@ -135,7 +140,11 @@ def load_pytree( The restored `PyTree`. """ start_time = time.time() - logging.info('Loading checkpoint from %s.', path) + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=False, + ).record_start() abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree) validation.validate_pytree_checkpointable_name(checkpointable_name) @@ -264,7 +273,11 @@ def load_checkpointables( FileNotFoundError: If the checkpoint path does not exist. """ start_time = time.time() - logging.info('Loading checkpoint from %s.', path) + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=False, + ).record_start() abstract_checkpointables = _standardize_abstract_checkpointables( abstract_checkpointables @@ -317,6 +330,11 @@ def _load_impl( async def _load() -> Any: load_awaitable = await load_fn() + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=False, + ).record_blocking_completion(time.time() - start_time) result = await load_awaitable await multihost.sync_global_processes( multihost.unique_barrier_key( @@ -330,17 +348,98 @@ async def _load() -> Any: result = asyncio_utils.run_sync(_load()) - event_tracking.record_read_event(path) - duration_secs = time.time() - start_time - logging.info( - 'Finished loading checkpoint in %.2f seconds from %s.', - duration_secs, + event_tracking.OperationRecorder( path, - ) + operation_type=event_tracking.OperationType.LOAD, + async_origin=False, + ).record_completion(duration_secs) return result +class _LoadPyTreeResponse( + AsyncResponse[tree_types.PyTreeOf[tree_types.LeafType]] +): + """An :py:class:`.AsyncResponse` for :py:func:`.load_pytree_async`.""" + + def __init__( + self, + operation_id: str, + path: 
path_types.Path, + background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.LeafType]], + *, + start_time: float, + context: context_lib.Context, + ): + self._operation_id = operation_id + self._path = path + self._background_awaitable = background_awaitable + self._start_time = start_time + self._context = context + self._thread_runner = thread_utils.BackgroundThreadRunner[ + tree_types.PyTreeOf[tree_types.LeafType] + ](self._finalize_load()) + + @classmethod + def create( + cls, + background_awaitable: Awaitable[tree_types.PyTreeOf[tree_types.LeafType]], + path: path_types.Path, + start_time: float, + *, + context: context_lib.Context, + ) -> _LoadPyTreeResponse: + """Creates and returns the final AsyncResponse for a load operation.""" + blocking_duration_secs = time.time() - start_time + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=True, + ).record_blocking_completion(blocking_duration_secs) + return cls( + context.operation_id(), + path, + background_awaitable, + start_time=start_time, + context=context, + ) + + async def _finalize_load(self) -> tree_types.PyTreeOf[tree_types.LeafType]: + logging.info( + '[process=%s] Waiting for background load operations', + multihost.process_index(), + ) + result = await self._background_awaitable + logging.vlog( + 1, + '[process=%s] Finished waiting for background load operations.', + multihost.process_index(), + ) + + await multihost.sync_global_processes( + multihost.unique_barrier_key( + '_load_async:finalize', + prefix=( + self._context.multiprocessing_options.barrier_sync_key_prefix + ), + ), + operation_id=self._context.operation_id(), + processes=self._context.multiprocessing_options.active_processes, + ) + total_duration_secs = time.time() - self._start_time + event_tracking.OperationRecorder( + self._path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=True, + ).record_completion(total_duration_secs) + return result + + def 
result( + self, timeout: float | None = None + ) -> tree_types.PyTreeOf[tree_types.LeafType]: + return self._thread_runner.result(timeout=timeout) + + def load_pytree_async( path: path_types.PathLike, abstract_pytree: ( @@ -349,9 +448,43 @@ def load_pytree_async( *, checkpointable_name: str | None = PYTREE_CHECKPOINTABLE_KEY, ) -> async_types.AsyncResponse[tree_types.PyTreeOf[tree_types.LeafType]]: - """Loads a PyTree asynchronously. Not yet implemented.""" - del path, abstract_pytree, checkpointable_name - raise NotImplementedError('Asynchronous loading is not yet supported.') + """Loads a PyTree asynchronously. Currently has limited support.""" + start_time = time.time() + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.LOAD, + async_origin=True, + ).record_start() + ctx = context_lib.get_context() + if not path: + raise ValueError('Path must not be None.') + if ctx.checkpoint_layout != options_lib.CheckpointLayout.SAFETENSORS: + raise NotImplementedError( + 'Asynchronous loading only supported for SAFETENSORS checkpoint ' + f'layout, not {ctx.checkpoint_layout}.' 
+ ) + path = ctx.file_options.path_class(path) + abstract_pytree = _standardize_abstract_checkpointables(abstract_pytree) + validation.validate_pytree_checkpointable_name(checkpointable_name) + + async def _blocking_load() -> Any: + layout = await layout_registry.get_checkpoint_layout_pytree( + path, ctx.checkpoint_layout, checkpointable_name + ) + return await layout.load_pytree( + path, + checkpointable_name=checkpointable_name, + abstract_pytree=abstract_pytree, + ) + + background_awaitable = asyncio_utils.run_sync(_blocking_load()) + response = _LoadPyTreeResponse.create( + background_awaitable, + path, + start_time=start_time, + context=ctx, + ) + return response def load_checkpointables_async( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index 617fa5a5f..086c277c5 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -124,15 +124,11 @@ def create( ) -> _SaveResponse: """Creates and returns the final AsyncResponse for a save operation.""" blocking_duration_secs = time.time() - start_time - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/blocking_duration_secs', - blocking_duration_secs, - ) - logging.info( - 'Finished blocking save in %.2f seconds. 
Continuing to write to %s.', - blocking_duration_secs, + event_tracking.OperationRecorder( temporary_path.temporary_path.get_final(), - ) + operation_type=event_tracking.OperationType.SAVE, + async_origin=async_origin, + ).record_blocking_completion(blocking_duration_secs) handler_typestrs = { name: handler_types.typestr(type(handler)) @@ -221,11 +217,11 @@ async def _finalize_save(self): processes=self._context.multiprocessing_options.active_processes, ) total_duration_secs = time.time() - self._start_time - event_tracking.record_save_completion( + event_tracking.OperationRecorder( self._temporary_path.temporary_path.get_final(), - total_duration_secs=total_duration_secs, + operation_type=event_tracking.OperationType.SAVE, async_origin=self._async_origin, - ) + ).record_completion(total_duration_secs) def result(self, timeout: float | None = None) -> None: return self._thread_runner.result(timeout=timeout) @@ -285,9 +281,6 @@ async def _run_blocking_save( path=temporary_path.path_awaiting_creation, checkpointables=checkpointables, ) - # Log write event for the final path. - event_tracking.record_write_event(temporary_path.temporary_path.get_final()) - return background_awaitable @@ -374,7 +367,11 @@ def save_checkpointables_impl( """See caller docstrings.""" validation.validate_abstract_checkpointables(checkpointables) start_time = time.time() - event_tracking.record_save_start(path, async_origin=async_origin) + event_tracking.OperationRecorder( + path, + operation_type=event_tracking.OperationType.SAVE, + async_origin=async_origin, + ).record_start() # Ensure the operation ID is incremented as soon as possible. This must be # done uniquely for each save operation. asyncio_utils.run_sync(context_lib.synchronize_next_operation_id())