Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1193,3 +1193,9 @@ distill_temperature: 1.0
# 0.0 value disables this feature.
distill_beta: 0.0
distill_layer_indices: None

##### Elastic training parameters
# Elastic training is Pathways-specific and does not work on McJAX.
elastic_enabled: false
elastic_timeout_seconds: 300
elastic_max_retries: 10
23 changes: 23 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,26 @@ class Goodput(BaseModel):
enable_gcp_step_deviation_metrics: bool = Field(True, description="Enable GCP step deviation metrics.")


class ElasticTraining(BaseModel):
  """Configuration for elastic training and fault tolerance.

  Elastic training is Pathways-specific and does not work on McJAX.
  Cross-field validation elsewhere in this module rejects
  `elastic_enabled=True` unless `enable_single_controller` is also set.
  """

  # Master switch for elastic training; mirrors `elastic_enabled` in base.yml.
  elastic_enabled: bool = Field(False, description="Whether to enable elastic training.")
  # NOTE(review): the description below references `elastic_minimum_slice_count`,
  # which is not a field defined in this config — confirm the intended name
  # (it may be a pathwaysutils elastic-manager concept rather than a config key).
  elastic_timeout_seconds: int = Field(
      300,
      description=(
          "The maximum number of seconds to wait for `elastic_minimum_slice_count` slices to become active. If this"
          " timeout is reached during any retry attempt, a `TimeoutError` is raised and training fails."
      ),
  )
  # Upper bound on restart attempts; forwarded to the elastic manager's retry decorator.
  elastic_max_retries: int = Field(
      10,
      description="The maximum number of times to retry training when a slice failure occurs or when scaling up.",
  )


class GcpMonitoring(BaseModel):
"""Configuration for GCP-specific workload monitoring."""

Expand Down Expand Up @@ -1947,6 +1967,7 @@ class MaxTextConfig(
Checkpointing,
OrbaxStorage,
EmergencyCheckpointing,
ElasticTraining,
# Data Types and Quantization
DataTypes,
Quantization,
Expand Down Expand Up @@ -2456,6 +2477,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# H. RUN ALL CROSS-FIELD VALIDATIONS
if self.load_parameters_path and self.load_full_state_path:
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
if self.elastic_enabled and not self.enable_single_controller:
raise ValueError("Elastic training is only supported with Pathways (`enable_single_controller=True`).")
if (self.load_parameters_path or self.load_full_state_path) and not self.enable_checkpointing:
raise ValueError("You must set enable_checkpointing=True to load a checkpoint.")
if self.enable_multi_tier_checkpointing:
Expand Down
25 changes: 24 additions & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from maxtext.configs import pyconfig
from maxtext.common.common_types import ShardMode
from maxtext.utils.globals import EPS
from maxtext.utils import elastic_utils
# Placeholder: internal

# pylint: disable=too-many-positional-arguments
Expand Down Expand Up @@ -678,8 +679,30 @@ def run(config, recorder, diagnostic_config):
def main(argv: Sequence[str]) -> None:
  """Entry point: build the config and recorders, then launch training.

  When elastic training is enabled, the train function is wrapped with
  `elastic_utils.elastic_retry` so that each retry attempt re-initializes
  config, recorder, and diagnostics from scratch; otherwise the objects
  created here are reused directly.
  """
  config, recorder, diagnostic_config = initialize(argv)
  record_goodput(recorder, RECORD_JOB_START_TIME)

  if not config.elastic_enabled:

    # Non-elastic path: run once with the objects initialized above.
    def train_func():
      run(config, recorder, diagnostic_config)

  else:
    max_logging.log("Elastic utils: Elastic training enabled.")

    def _elastic_attempt():
      """One training attempt for elastic retry.

      Re-initializes all state so a retry starts clean.
      """
      attempt_config, attempt_recorder, attempt_diagnostic_config = initialize(argv)
      run(
          attempt_config,
          attempt_recorder,
          attempt_diagnostic_config,
      )

    train_func = elastic_utils.elastic_retry(config)(_elastic_attempt)

  with maybe_monitor_goodput(config):
    train_func()


if __name__ == "__main__":
Expand Down
129 changes: 129 additions & 0 deletions src/maxtext/utils/elastic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for Elastic Training."""

import functools
import jax
from maxtext.utils import gcs_utils
from maxtext.utils import max_logging
import pathwaysutils
from pathwaysutils.elastic import manager


elastic_manager: manager.Manager | None = None


def elastic_enabled(config) -> bool:
  """Return whether elastic mode is active (Pathways backend + config flag)."""
  if not pathwaysutils.is_pathways_backend_used():
    return False
  return config.elastic_enabled


def clean_up_checkpoints(checkpoint_dir: str):
  """Delete the newest checkpoint if it lacks a commit_success marker.

  After an elastic event the most recent checkpoint may have been only
  partially written; the absence of a `commit_success*` file under its
  directory identifies it as incomplete, in which case it is removed.

  Args:
    checkpoint_dir: GCS path containing per-step checkpoint directories.
  """
  max_logging.log("Elastic utils: Checking for incomplete checkpoint after an elastic event...")
  base_dir = gcs_utils.add_trailing_slash(checkpoint_dir)

  # Checkpoint step directories are named by their integer step number;
  # ignore anything else that lives under the checkpoint root.
  step_dirs = [name for name in gcs_utils.gcs_list_directories(base_dir) if name.isdigit()]

  if not step_dirs:
    max_logging.log("Found no existing checkpoints. Continuing")
    return

  # Numerically latest step (equivalent to sorting by int and taking the last).
  latest_name = max(step_dirs, key=int)
  latest_path = f"{base_dir}{latest_name}/"

  max_logging.log(f"Checking latest checkpoint: {latest_path}")

  # A completed checkpoint carries a commit_success marker file.
  if gcs_utils.gcs_glob_pattern(f"{latest_path}commit_success*"):
    max_logging.log(f"Found commit_success file. Keeping {latest_path}.")
  else:
    max_logging.log(f"No commit_success file found. Deleting {latest_path}...")
    gcs_utils.gcs_delete_directory(latest_path)


def live_devices():
  """Return the devices currently considered live.

  Outside the Pathways backend every device is live; under Pathways,
  only devices whose slice is active according to the elastic manager
  are returned. The module-level manager is created lazily on first use.
  """
  global elastic_manager
  if not pathwaysutils.is_pathways_backend_used():
    # No elastic manager without Pathways — all devices are live.
    return jax.devices()
  if elastic_manager is None:
    elastic_manager = manager.Manager()
  active_slices = elastic_manager.active_slice_indices
  return [device for device in jax.devices() if device.slice_index in active_slices]


def chain_callbacks(*funcs):
  """Return a zero-argument callable that invokes each of ``funcs`` in order."""

  def _invoke_all():
    for callback in funcs:
      callback()

  return _invoke_all


def elastic_retry(config, callback_fn=None):
  """Decorator for elastic retry.

  If an elastic event occurs, the decorator will retry the decorated function
  up to `config.elastic_max_retries` times.
  Before each retry, it cleans up partial checkpoints by calling
  `clean_up_checkpoints`. If `callback_fn` is provided, it is
  called after `clean_up_checkpoints`.

  Args:
    config: Config object. Must have elastic training enabled and run on the
      Pathways backend (see `elastic_enabled`), otherwise a ValueError is raised.
    callback_fn: Optional callback function to be called after
      `clean_up_checkpoints` on an elastic event.

  Returns:
    A decorator for elastic retry (produced by the pathwaysutils elastic
    manager's `elastic_retry`).

  Raises:
    ValueError: If elastic training is not enabled or the Pathways backend
      is not in use.
  """
  global elastic_manager
  # Fail fast with a descriptive message instead of letting the manager
  # misbehave on an unsupported backend.
  if not elastic_enabled(config):
    msg = (
        "Elastic training requires the Pathways backend, and elastic_enabled"
        " must be set to True: current config.elastic_enabled:"
        f" {config.elastic_enabled}, pathways backend used:"
        f" {pathwaysutils.is_pathways_backend_used()}"
    )
    raise ValueError(msg)

  max_logging.log("Elastic Retry Enabled")
  # Lazily create the module-level manager so repeated calls (and
  # `live_devices`) share a single instance.
  if elastic_manager is None:
    elastic_manager = manager.Manager()

  # Checkpoint cleanup always runs on an elastic event; a user callback,
  # if supplied, runs after it.
  cleanup_partial = functools.partial(clean_up_checkpoints, config.checkpoint_dir)

  if callback_fn is None:
    effective_callback = cleanup_partial
  else:
    effective_callback = chain_callbacks(cleanup_partial, callback_fn)

  # NOTE(review): keyword names (`max_retries`, `timeout`,
  # `on_elastic_event_callback`) must match the pathwaysutils
  # elastic.manager.Manager.elastic_retry signature — confirm against the
  # installed pathwaysutils version.
  return elastic_manager.elastic_retry(
      max_retries=config.elastic_max_retries,
      timeout=config.elastic_timeout_seconds,
      on_elastic_event_callback=effective_callback,
  )
32 changes: 31 additions & 1 deletion src/maxtext/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Common GCS Utils needed by multiple modules"""
"""Common GCS Utils needed by multiple modules"""
import shutil
import json
import os
import socket
from pathlib import Path
from etils import epath
import uuid
from concurrent.futures import ThreadPoolExecutor

import yaml

Expand Down Expand Up @@ -168,6 +169,35 @@ def gcs_list_directories(directory_path):
return directories


def gcs_delete_directory(directory_path: str):
  """Deletes a "directory" (all blobs with the prefix) from GCS.

  Args:
    directory_path: The GCS path (gs://...) representing the "directory" to delete.
  """
  if not _gcs_guard("gcs_delete_directory"):
    return
  bucket_name, prefix = parse_gcs_bucket_and_prefix(directory_path)
  bucket = storage.Client().bucket(bucket_name)

  # A trailing slash on the prefix avoids deleting more than intended
  # (e.g. "steps/10" also matching "steps/100").
  if not prefix.endswith("/"):
    prefix += "/"

  blobs = list(bucket.list_blobs(prefix=prefix))
  if not blobs:
    return

  def _delete_blob(blob):
    try:
      blob.delete()
    except Exception as e:  # pylint: disable=broad-except
      max_logging.log(f"Error deleting blob {blob.name}: {e}")

  # Parallel deletion to match `gsutil -m` performance; the context manager
  # blocks until all submitted deletions finish.
  with ThreadPoolExecutor(max_workers=32) as executor:
    executor.map(_delete_blob, blobs)


def gcs_glob_pattern(pattern):
"""
Globs GCS files and returns a list of full GCS paths.
Expand Down
Loading
Loading