From 19636889843851faff1df1caa47b72e59b0d2b1b Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 6 Apr 2026 20:41:23 -0700 Subject: [PATCH] Add elastic pause/resume functionality to MaxText. PiperOrigin-RevId: 895636178 --- src/maxtext/configs/base.yml | 6 + src/maxtext/configs/types.py | 23 +++ src/maxtext/trainers/pre_train/train.py | 27 ++- src/maxtext/utils/elastic_utils.py | 131 ++++++++++++ src/maxtext/utils/gcs_utils.py | 32 ++- tests/unit/elastic_utils_test.py | 252 ++++++++++++++++++++++++ 6 files changed, 469 insertions(+), 2 deletions(-) create mode 100644 src/maxtext/utils/elastic_utils.py create mode 100644 tests/unit/elastic_utils_test.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 2116ad77bc..06063c5fcf 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1194,3 +1194,9 @@ distill_temperature: 1.0 # 0.0 value disables this feature. distill_beta: 0.0 distill_layer_indices: None + +##### Elastic training parameters +# Elastic training is Pathways-specific and does not work on McJAX. +elastic_enabled: false +elastic_timeout_seconds: 300 +elastic_max_retries: 10 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index fcb7b0a789..3f56bf7406 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1551,6 +1551,26 @@ class Goodput(BaseModel): enable_gcp_step_deviation_metrics: bool = Field(True, description="Enable GCP step deviation metrics.") +class ElasticTraining(BaseModel): + """Configuration for elastic training and fault tolerance. + + Elastic training is Pathways-specific and does not work on McJAX. + """ + + elastic_enabled: bool = Field(False, description="Whether to enable elastic training.") + elastic_timeout_seconds: int = Field( + 300, + description=( + "The maximum number of seconds to wait for `elastic_minimum_slice_count` slices to become active. If this" + " timeout is reached during any retry attempt, a `TimeoutError` is raised and training fails." + ), + ) + elastic_max_retries: int = Field( + 10, + description="The maximum number of times to retry training when a slice failure occurs or when scaling up.", + ) + + class GcpMonitoring(BaseModel): """Configuration for GCP-specific workload monitoring.""" @@ -1948,6 +1968,7 @@ class MaxTextConfig( Checkpointing, OrbaxStorage, EmergencyCheckpointing, + ElasticTraining, # Data Types and Quantization DataTypes, Quantization, @@ -2457,6 +2478,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # H. RUN ALL CROSS-FIELD VALIDATIONS if self.load_parameters_path and self.load_full_state_path: raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.") + if self.elastic_enabled and not self.enable_single_controller: + raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).") if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing: raise ValueError("You must set enable_checkpointing=True to load a checkpoint.") if self.enable_multi_tier_checkpointing: diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index a3c39acb9f..2c374ba651 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -41,6 +41,7 @@ from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils.globals import EPS +from maxtext.utils import elastic_utils # Placeholder: internal # pylint: disable=too-many-positional-arguments @@ -675,11 +676,35 @@ def run(config, recorder, diagnostic_config): train_loop(config, recorder) +def get_train_func(config, recorder, diagnostic_config, argv): + """Returns the train function, wrapping in elastic_retry if elastic training is enabled.""" + if config.elastic_enabled: + max_logging.log("Elastic utils: Elastic training enabled.") + + def elastic_train_wrapper(argv: Sequence[str]) -> None: + """Wrapper for elastic training initializes variables and runs the train loop.""" + elastic_config, elastic_recorder, elastic_diagnostic_config = initialize(argv) + run( + elastic_config, + elastic_recorder, + elastic_diagnostic_config, + ) + + train_func = elastic_utils.elastic_retry(config)(functools.partial(elastic_train_wrapper, argv=argv)) + else: + # Use the already initialized variables + def train_func(): + run(config, recorder, diagnostic_config) + + return train_func + + def main(argv: Sequence[str]) -> None: config, recorder, diagnostic_config = initialize(argv) record_goodput(recorder, RECORD_JOB_START_TIME) + train_func = get_train_func(config, recorder, diagnostic_config, argv) with maybe_monitor_goodput(config): - run(config, recorder, diagnostic_config) + train_func() if __name__ == "__main__": diff --git a/src/maxtext/utils/elastic_utils.py b/src/maxtext/utils/elastic_utils.py new file mode 100644 index 0000000000..d8dd50aaec --- /dev/null +++ b/src/maxtext/utils/elastic_utils.py @@ -0,0 +1,131 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for Elastic Training.""" + +import functools +import jax +from maxtext.utils import gcs_utils +from maxtext.utils import max_logging +import pathwaysutils +from pathwaysutils.elastic import manager + + +elastic_manager: manager.Manager | None = None + + +def elastic_enabled(config) -> bool: + """Returns whether elastic mode is enabled.""" + return pathwaysutils.is_pathways_backend_used() and config.elastic_enabled + + +def clean_up_checkpoints(checkpoint_dir: str): + """Cleans up incomplete checkpoints after an elastic event.""" + max_logging.log("Elastic utils: Checking for incomplete checkpoint after an elastic event...") + checkpoint_dir = gcs_utils.add_trailing_slash(checkpoint_dir) + + # 1. List the "directories" (steps) + checkpoints = gcs_utils.gcs_list_directories(checkpoint_dir) + + # 2. Filter for directories that are numbers + checkpoints = [cp for cp in checkpoints if cp.isdigit()] + + if not checkpoints: + max_logging.log("Found no existing checkpoints. Continuing") + return + + # Sort naturally (numerical sort) and get the last one + checkpoints.sort(key=int) + latest_checkpoint_name = checkpoints[-1] + latest_checkpoint_path = f"{checkpoint_dir}{latest_checkpoint_name}/" + + max_logging.log(f"Checking latest checkpoint: {latest_checkpoint_path}") + + # 3. Check for commit_success file + success_markers = gcs_utils.gcs_glob_pattern(f"{latest_checkpoint_path}commit_success*") + + if not success_markers: + max_logging.log(f"No commit_success file found. Deleting {latest_checkpoint_path}...") + # TODO: Use Orbax 'Cancel Ongoing Checkpointing' API when available to + # prevent deleting a checkpoint that is currently being written. + gcs_utils.gcs_delete_directory(latest_checkpoint_path) + else: + max_logging.log(f"Found commit_success file. Keeping {latest_checkpoint_path}.") + + +def live_devices(): + """Returns the list of live devices.""" + global elastic_manager + # If pathways is not used or elastic_manager is not initialized, return all devices + if pathwaysutils.is_pathways_backend_used(): + if elastic_manager is None: + elastic_manager = manager.Manager() + # Filter devices that are in active slices + return [d for d in jax.devices() if d.slice_index in elastic_manager.active_slice_indices] + return jax.devices() + + +def chain_callbacks(*funcs): + """Helper function to chain callbacks.""" + + def wrapper(): + for func in funcs: + func() + + return wrapper + + +def elastic_retry(config, callback_fn=None): + """Decorator for elastic retry. + + If an elastic event occurs, the decorator will retry the decorated function + up to `config.elastic_max_retries` times. + Before each retry, it cleans up partial checkpoints by calling + `clean_up_checkpoints`. If `callback_fn` is provided, it is + called after `clean_up_checkpoints`. + + Args: + config: Config object. + callback_fn: Optional callback function to be called after + `clean_up_checkpoints` on an elastic event. + + Returns: + A decorator for elastic retry. + """ + global elastic_manager + if not elastic_enabled(config): + msg = ( + "Elastic training requires the Pathways backend, and elastic_enabled" + " must be set to True: current config.elastic_enabled:" + f" {config.elastic_enabled}, pathways backend used:" + f" {pathwaysutils.is_pathways_backend_used()}" + ) + raise ValueError(msg) + + max_logging.log("Elastic Retry Enabled") + if elastic_manager is None: + elastic_manager = manager.Manager() + + cleanup_partial = functools.partial(clean_up_checkpoints, config.checkpoint_dir) + + if callback_fn is None: + effective_callback = cleanup_partial + else: + effective_callback = chain_callbacks(cleanup_partial, callback_fn) + + return elastic_manager.elastic_retry( + max_retries=config.elastic_max_retries, + timeout=config.elastic_timeout_seconds, + on_elastic_event_callback=effective_callback, + ) diff --git a/src/maxtext/utils/gcs_utils.py b/src/maxtext/utils/gcs_utils.py index b0f8a98c01..ade7fb79ad 100644 --- a/src/maxtext/utils/gcs_utils.py +++ b/src/maxtext/utils/gcs_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common GCS Utils needed by multiple modules""" +"""Common GCS Utils needed by multiple modules""" import shutil import json import os @@ -20,6 +20,7 @@ from pathlib import Path from etils import epath import uuid +from concurrent.futures import ThreadPoolExecutor import yaml @@ -168,6 +169,35 @@ def gcs_list_directories(directory_path): return directories +def gcs_delete_directory(directory_path: str): + """Deletes a "directory" (all blobs with the prefix) from GCS. + + Args: + directory_path: The GCS path (gs://...) representing the "directory" to delete. + """ + if not _gcs_guard("gcs_delete_directory"): + return + storage_client = storage.Client() + bucket_name, directory_prefix = parse_gcs_bucket_and_prefix(directory_path) + bucket = storage_client.bucket(bucket_name) + + # Ensures the prefix has a trailing slash to avoid deleting more than intended. + if not directory_prefix.endswith("/"): + directory_prefix += "/" + + blobs = list(bucket.list_blobs(prefix=directory_prefix)) + if blobs: + # Uses a ThreadPoolExecutor to delete blobs in parallel to match gsutil -m performance. + def _delete_blob(blob): + try: + blob.delete() + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Error deleting blob {blob.name}: {e}") + + with ThreadPoolExecutor(max_workers=32) as executor: + executor.map(_delete_blob, blobs) + + def gcs_glob_pattern(pattern): """ Globs GCS files and returns a list of full GCS paths. diff --git a/tests/unit/elastic_utils_test.py b/tests/unit/elastic_utils_test.py new file mode 100644 index 0000000000..28a8998cb3 --- /dev/null +++ b/tests/unit/elastic_utils_test.py @@ -0,0 +1,252 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Elastic Training utility functions.""" + +import unittest +import pytest + +from maxtext.utils import elastic_utils +from maxtext.utils import gcs_utils +import pathwaysutils + + +class FakeGcsUtils: + """Fake implementation for gcs_utils functions.""" + + def __init__(self): + self.directories = {} + self.files = set() + self.deleted_directories = [] + + def gcs_list_directories(self, path): + if path in self.directories: + return self.directories[path] + if path.endswith("/") and path[:-1] in self.directories: + return self.directories[path[:-1]] + return [] + + def gcs_glob_pattern(self, pattern): + # Very simple glob implementation for testing + prefix = pattern.replace("*", "") + return [f for f in self.files if f.startswith(prefix)] + + def gcs_delete_directory(self, path): + self.deleted_directories.append(path) + + @staticmethod + def add_trailing_slash(path): + return gcs_utils.add_trailing_slash(path) + + @staticmethod + def parse_gcs_bucket_and_prefix(path): + return gcs_utils.parse_gcs_bucket_and_prefix(path) + + +class FakeManager: + """Fake implementation for pathwaysutils.elastic.manager.Manager.""" + + def __init__(self): + self.active_slice_indices = set() + self.elastic_retry_called = False + self.elastic_retry_kwargs = {} + + def elastic_retry(self, **kwargs): + self.elastic_retry_called = True + self.elastic_retry_kwargs = kwargs + + +class FakePathwaysUtils: + """Fake implementation for pathwaysutils.""" + + def __init__(self): + self.is_pathways_used = True + + def is_pathways_backend_used(self): + return self.is_pathways_used + + +class FakeLogging: + """Fake implementation for max_logging.""" + + def __init__(self): + self.logs = [] + + def log(self, message): + self.logs.append(message) + + +class FakeJax: + """Fake implementation for jax.""" + + def __init__(self): + self.devices_list = [] + + def devices(self, *args, **kwargs): + return self.devices_list + + +class FakeDevice: + """Fake Device object.""" + + def __init__(self, slice_index=0): + self.slice_index = slice_index + + +class FakeConfig: + """Fake configuration object.""" + + def __init__(self): + self.elastic_enabled = True + self.checkpoint_dir = "gs://test_bucket/checkpoints" + self.elastic_max_retries = 3 + self.elastic_timeout_seconds = 100 + + +@pytest.mark.cpu_only +class ElasticUtilsTest(unittest.TestCase): + """Unit tests for Elastic Training utility functions.""" + + def setUp(self): + super().setUp() + # Save original dependencies + self.original_pathwaysutils = elastic_utils.pathwaysutils + self.original_jax = elastic_utils.jax + self.original_gcs_utils = elastic_utils.gcs_utils + self.original_max_logging = elastic_utils.max_logging + self.original_manager_class = pathwaysutils.elastic.manager.Manager + + # Initialize fakes + self.fake_gcs_utils = FakeGcsUtils() + self.fake_pathwaysutils = FakePathwaysUtils() + self.fake_logging = FakeLogging() + self.fake_jax = FakeJax() + self.fake_manager = FakeManager() + + # Inject fakes into elastic_utils namespace + elastic_utils.pathwaysutils = self.fake_pathwaysutils + elastic_utils.jax = self.fake_jax + elastic_utils.gcs_utils = self.fake_gcs_utils + elastic_utils.max_logging = self.fake_logging + + # Hook up pathwaysutils.elastic.manager.Manager to return our fake_manager + pathwaysutils.elastic.manager.Manager = lambda *args, **kwargs: self.fake_manager + + # Reset global state for testing + elastic_utils.elastic_manager = None + + def tearDown(self): + # Restore original dependencies + elastic_utils.pathwaysutils = self.original_pathwaysutils + elastic_utils.jax = self.original_jax + elastic_utils.gcs_utils = self.original_gcs_utils + elastic_utils.max_logging = self.original_max_logging + pathwaysutils.elastic.manager.Manager = self.original_manager_class + super().tearDown() + + def test_elastic_enabled(self): + """Tests elastic_enabled.""" + config = FakeConfig() + self.fake_pathwaysutils.is_pathways_used = True + config.elastic_enabled = True + self.assertTrue(elastic_utils.elastic_enabled(config)) + + config.elastic_enabled = False + self.assertFalse(elastic_utils.elastic_enabled(config)) + + config.elastic_enabled = True + self.fake_pathwaysutils.is_pathways_used = False + self.assertFalse(elastic_utils.elastic_enabled(config)) + + def test_clean_up_checkpoints_no_checkpoints(self): + """Tests clean_up_checkpoints when no checkpoints exist.""" + self.fake_gcs_utils.directories = {"gs://test_bucket/checkpoints": []} + elastic_utils.clean_up_checkpoints("gs://test_bucket/checkpoints") + self.assertEqual(len(self.fake_gcs_utils.deleted_directories), 0) + + def test_clean_up_checkpoints_incomplete(self): + """Tests clean_up_checkpoints when the latest checkpoint is incomplete.""" + checkpoint_dir = "gs://test_bucket/checkpoints" + self.fake_gcs_utils.directories = {checkpoint_dir: ["1", "2", "10"]} + # No commit_success for "10" + elastic_utils.clean_up_checkpoints(checkpoint_dir) + self.assertIn(f"{checkpoint_dir}/10/", self.fake_gcs_utils.deleted_directories) + self.assertNotIn(f"{checkpoint_dir}/1/", self.fake_gcs_utils.deleted_directories) + self.assertNotIn(f"{checkpoint_dir}/2/", self.fake_gcs_utils.deleted_directories) + + def test_clean_up_checkpoints_complete(self): + """Tests clean_up_checkpoints when the latest checkpoint is complete.""" + checkpoint_dir = "gs://test_bucket/checkpoints" + self.fake_gcs_utils.directories = {checkpoint_dir: ["1", "2", "10"]} + self.fake_gcs_utils.files.add(f"{checkpoint_dir}/10/commit_success_0") + elastic_utils.clean_up_checkpoints(checkpoint_dir) + self.assertEqual(len(self.fake_gcs_utils.deleted_directories), 0) + + def test_live_devices_no_pathways(self): + """Tests live_devices when pathways is not used.""" + self.fake_pathwaysutils.is_pathways_used = False + device0 = FakeDevice(slice_index=0) + self.fake_jax.devices_list = [device0] + + devices = elastic_utils.live_devices() + self.assertEqual(devices, [device0]) + + def test_elastic_retry_disabled(self): + """Tests elastic_retry when disabled but pathways is used.""" + self.fake_pathwaysutils.is_pathways_used = True + config = FakeConfig() + config.elastic_enabled = False + msg = ( + "Elastic training requires the Pathways backend, and elastic_enabled" + " must be set to True: current config.elastic_enabled: False, pathways" + " backend used: True" + ) + with self.assertRaisesRegex(ValueError, msg): + elastic_utils.elastic_retry(config) + + def test_elastic_retry_no_pathways(self): + """Tests elastic_retry when enabled but pathways is not used.""" + self.fake_pathwaysutils.is_pathways_used = False + config = FakeConfig() + config.elastic_enabled = True + msg = ( + "Elastic training requires the Pathways backend, and elastic_enabled" + " must be set to True: current config.elastic_enabled: True, pathways" + " backend used: False" + ) + with self.assertRaisesRegex(ValueError, msg): + elastic_utils.elastic_retry(config) + + def test_chain_callbacks(self): + """Tests chain_callbacks.""" + # Test with no functions + chained_fn_empty = elastic_utils.chain_callbacks() + chained_fn_empty() # Should not fail + + # Test with multiple functions + call_order = [] + + def fn1(): + call_order.append(1) + + def fn2(): + call_order.append(2) + + chained_fn = elastic_utils.chain_callbacks(fn1, fn2) + chained_fn() + self.assertEqual(call_order, [1, 2]) + + +if __name__ == "__main__": + unittest.main()