diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 53a49f7c1..a47e9e1f7 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite. - #v1 Add `DeletionOptions` to configure V1 Checkpointer's checkpoint deletion behavior. +- #v1 Add `cleanup_tmp_directories` setting to V1 Checkpointer to manage +temporary directory cleanup behavior. ### Removed diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 590107092..943720469 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -87,6 +87,7 @@ def __init__( path_step_lib.NameFormat[path_step_lib.Metadata] | None ) = None, custom_metadata: tree_types.JsonType | None = None, + cleanup_tmp_directories: bool = False, ): """Initializes a Checkpointer. @@ -150,6 +151,8 @@ def __init__( custom_metadata: A JSON dictionary representing user-specified custom metadata. This should be information that is relevant to the entire sequence of checkpoints, rather than to any single checkpoint. + cleanup_tmp_directories: If True, cleans up any existing temporary + directories on Checkpointer creation. """ context = context_lib.get_context() @@ -169,6 +172,7 @@ def __init__( save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, step_name_format=step_name_format, + cleanup_tmp_directories=cleanup_tmp_directories, max_to_keep=None, # Unlimited. todelete_full_path=context.deletion_options.gcs_deletion_options.todelete_full_path, async_options=context.async_options.v0(), diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 4edc518af..1b585c791 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -682,3 +682,25 @@ def test_gcs_deletion_options(self): checkpointer._manager._options.todelete_full_path, 'trash' ) + + @parameterized.named_parameters( + dict( + testcase_name='true', + cleanup_tmp_directories=True, + ), + dict( + testcase_name='false', + cleanup_tmp_directories=False, + ), + ) + def test_cleanup_tmp_directories( + self, cleanup_tmp_directories + ): + checkpointer = Checkpointer( + self.directory, cleanup_tmp_directories=cleanup_tmp_directories + ) + self.assertIs( + checkpointer._manager._options.cleanup_tmp_directories, + cleanup_tmp_directories, + ) + checkpointer.close()