diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 05e3b344a..09bc0f6c5 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -475,7 +475,13 @@ async def _save( directory, ) - if await async_path.exists(directory): + exists_start = time.time() + dir_exists = await async_path.exists(directory) + jax.monitoring.record_event_duration_secs( + '/jax/orbax/write/async/foreground/check_dir_exists_secs', + time.time() - exists_start, + ) + if dir_exists: if force: if utils.is_primary_host(self._primary_host): logging.info( @@ -498,7 +504,13 @@ async def _save( ) ) else: + create_dir_start = time.time() await self.create_temporary_path(tmpdir) + jax.monitoring.record_event_duration_secs( + '/jax/orbax/write/async/foreground/create_dir_secs', + time.time() - create_dir_start, + ) + # Run copy ops. # Try to save using new CheckpointArgs API if supported by the handler. ckpt_args = checkpointer.construct_checkpoint_args( diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index ce6f90fb5..f74862211 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -776,6 +776,18 @@ async def async_save( total_serialization_initiated_time - batch_requests_ready_time, async_save_end_time - total_serialization_initiated_time, ) + jax.monitoring.record_event_duration_secs( + '/jax/orbax/write/async/foreground/batch_requests_ready_secs', + batch_requests_ready_time - start_time, + ) + jax.monitoring.record_event_duration_secs( + '/jax/orbax/write/async/foreground/d2h_transfer_secs', + total_serialization_initiated_time - batch_requests_ready_time, + ) + jax.monitoring.record_event_duration_secs( + '/jax/orbax/write/async/foreground/commit_write_metadata_prep_secs', + async_save_end_time - total_serialization_initiated_time, + ) return chained_futures def save(self, directory: epath.Path, *args, **kwargs):