diff --git a/benchmarks/maxtext_trillium_model_configs.py b/benchmarks/maxtext_trillium_model_configs.py index 4950c8f57b..35e9fbc2db 100644 --- a/benchmarks/maxtext_trillium_model_configs.py +++ b/benchmarks/maxtext_trillium_model_configs.py @@ -1721,7 +1721,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": -1, "remat_policy": "custom", "decoder_layer_input": "device", @@ -1739,7 +1739,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, @@ -1760,7 +1762,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": 1, "ici_fsdp_transpose_parallelism": -1, "remat_policy": "custom", @@ -1779,7 +1781,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, @@ -1800,7 +1804,7 @@ model_type="gemma3-12b", tuning_params={ "per_device_batch_size": 1, - "num_vocab_tiling": 16, + "num_batch_seq_tiling": 16, "ici_fsdp_parallelism": 1, "ici_fsdp_transpose_parallelism": -1, "remat_policy": "custom", @@ -1819,7 +1823,9 @@ "profiler": "xplane", "skip_first_n_steps_for_profiler": 10, "profiler_steps": 2, - "tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"), + "tokenizer_path": os.path.join( + "assets", "tokenizers", "tokenizer.gemma3" + ), "sa_block_q": 1024, "sa_block_kv": 1024, "sa_block_kv_compute": 1024, diff --git 
a/benchmarks/maxtext_v5p_model_configs.py b/benchmarks/maxtext_v5p_model_configs.py index f228b0f7fc..3df783022b 100644 --- a/benchmarks/maxtext_v5p_model_configs.py +++ b/benchmarks/maxtext_v5p_model_configs.py @@ -38,7 +38,7 @@ "remat_policy": "custom", "context": "offload", "mlpwo": "offload", - "num_vocab_tiling": 4, + "num_batch_seq_tiling": 4, "sa_block_q": 2048, "sa_block_kv": 2048, "sa_block_kv_compute": 2048, diff --git a/docs/reference/core_concepts/tiling.md b/docs/reference/core_concepts/tiling.md index 5d203e31aa..ef1efbb298 100644 --- a/docs/reference/core_concepts/tiling.md +++ b/docs/reference/core_concepts/tiling.md @@ -67,9 +67,9 @@ The final output unembedding layer of a language model maps hidden states to log Vocabulary tiling avoids materializing the full logits tensor. Instead, it tiles the input hidden states and computes the logits, loss, and gradients one tile at a time. Unlike GA, which is applied at the start of the model, vocabulary tiling is applied only to the input of the final layer. -In MaxText, the `num_vocab_tiling` configuration controls the number of tiles. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance. +In MaxText, the `num_batch_seq_tiling` configuration controls the number of tiles along the batch and sequence axes. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. One may also tile the vocabulary dimension using the `num_vocab_tiling` configuration. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance. 
-![Illustration of vocabulary tiling.](../../_static/vocab_tiling.png) +![Illustration of batch-sequence tiling.](../../_static/vocab_tiling.png) *Figure 2: Vocabulary tiling processes hidden states in tiles to avoid generating the full logits tensor.* ### Other Tiling Methods diff --git a/docs/release_notes.md b/docs/release_notes.md index 7192da4d77..3835f67033 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -47,7 +47,8 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins - [Optimized models tiering documentation](https://maxtext.readthedocs.io/en/latest/reference/models/tiering.html) has been refreshed. - Added Versioning. Check out our [first set of release notes](https://maxtext.readthedocs.io/en/latest/release_notes.html)! - Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available. -- Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage. +- Batch-Sequence tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_batch_seq_tiling` to unlock more efficient memory usage. +- Vocabulary tiling: additionally, the vocabulary dimension can also be tiled by adjusting `num_vocab_tiling`. - The GPT-OSS family of models (20B, 120B) is now supported. # Deprecations diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 3ff1c33153..83adff99f9 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -586,10 +586,12 @@ num_slices: -1 # Vocab Tiling Configs # Enables a memory-saving optimization by computing the cross-entropy loss in chunks. -# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis, -# reducing peak memory usage. This is highly recommended for models with large -# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable. 
+# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis, +# and `num_batch_seq_tiling` parts along the batch-sequence axis, reducing peak memory usage. +# This is highly recommended for models with large vocabularies (e.g., Gemma). +# Set to a value greater than 1 to enable. num_vocab_tiling: 1 +num_batch_seq_tiling: 1 # Tokenizer vocab_size: 32_000 # powers of 2 for sharding diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 888a23b199..ffc401c2df 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -194,9 +194,34 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - ) -def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): - if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: - raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") +def validate_batch_seq_tiling( + num_batch_seq_tiling: int, + per_device_batch_size: int, + max_target_length: int, + enable_nnx: bool, +): + if (per_device_batch_size * max_target_length) % num_batch_seq_tiling != 0: + raise ValueError( + "Per device batch size times sequence length should be divisible by the" + " number of batch seq tiles." + ) + if ( + num_batch_seq_tiling > 1 and enable_nnx + ): # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration + raise ValueError( + "We currently don't support batch seq tiling on NNX module." + ) + + +def validate_vocab_tiling( + num_vocab_tiling: int, + vocab_size: int, + enable_nnx: bool, +): + if vocab_size % num_vocab_tiling != 0: + raise ValueError( + "vocab_size should be divisible by the number of vocab tiles." 
+ ) if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration raise ValueError("We currently don't support vocab tiling on NNX module.") @@ -240,8 +265,16 @@ def validate_keys(keys): validate_model_call_mode(keys["model_call_mode"]) validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"]) validate_rope_type(keys["rope_type"]) + validate_batch_seq_tiling( + keys["num_batch_seq_tiling"], + keys["per_device_batch_size"], + keys["max_target_length"], + keys["enable_nnx"], + ) validate_vocab_tiling( - keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"] + keys["num_vocab_tiling"], + keys["vocab_size"], + keys["enable_nnx"], ) if keys["enable_rampup_batch_size"]: validate_rampup_batch_size( diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 5c97ac2c1e..59c3199e02 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -983,7 +983,17 @@ class Tokenizer(BaseModel): ) num_vocab_tiling: int = Field( 1, - description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.", + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the vocabulary axis. >1 to enable." + ), + ) + num_batch_seq_tiling: int = Field( + 1, + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the batch-sequence axis. >1 to enable." + ), ) @@ -2503,12 +2513,23 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.quantization: raise ValueError("Quantization is not supported with 'explicit' sharding.") + if self.vocab_size % self.num_vocab_tiling != 0: + raise ValueError( + "vocab_size should be divisible by the number of vocab tiles." 
+ ) if ( self.per_device_batch_size > 0 - and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 + and (self.per_device_batch_size * self.max_target_length) + % self.num_batch_seq_tiling + != 0 ): - raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: + raise ValueError( + "Per device batch size times sequence length should be divisible by" + " the number of batch tiles." + ) + if ( + self.num_vocab_tiling > 1 or self.num_batch_seq_tiling > 1 + ) and self.enable_nnx: raise ValueError("We currently don't support vocab tiling on NNX module.") if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 0ab392ecc2..f91c124235 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -1085,9 +1085,9 @@ def __call__( # for efficiency, as the main model is frozen and the LM loss is not needed. 
elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN: logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow("intermediates", "hidden_states", hidden_state) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..98658e83d0 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -1057,9 +1057,9 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): if cfg.attention == "vllm_rpa": logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory + # When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + if cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 0d1fcab700..c089f40b35 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -531,8 +531,8 @@ def __call__( mutable=mutable_collections, ) # pytype: disable=wrong-keyword-args - # Materialize hidden state when vocab tiling is enabled - if self.config.num_vocab_tiling > 1: + # Materialize hidden state when batch-sequence tiling is enabled. 
+ if self.config.num_batch_seq_tiling > 1: self.hidden_states = hidden_state # If we are initializing the model AND MTP is enabled, we must create diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fc973990ec..2abac19d52 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -141,7 +141,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): # The main model parameters are frozen and only the indexer is trained via KL divergence. total_loss = 0.0 total_z_loss = 0.0 - elif config.num_vocab_tiling > 1: + elif config.num_batch_seq_tiling > 1: hidden_state_key = ("intermediates", "decoder", "hidden_states") hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index ec68e9bc78..0838b54d86 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -117,13 +117,23 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation): batch_size, seq_len, emb_dim = hidden_states.shape - vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + batch_seq_tile_size = (batch_size * seq_len) // config.num_batch_seq_tiling reshaped_hidden_states = _reshape( - hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + hidden_states, + (config.num_batch_seq_tiling, batch_seq_tile_size, emb_dim), + reshaped_hidden_spec, + ) + reshaped_labels = _reshape( + labels, + (config.num_batch_seq_tiling, batch_seq_tile_size), + reshaped_data_spec, + ) + reshaped_segmentation = _reshape( + segmentation, + (config.num_batch_seq_tiling, batch_seq_tile_size), + 
reshaped_data_spec, ) - reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) - reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) # Scan body accumulates loss from each tile given chunked hidden states and labels def _fwd_scan_body(accumulators, chunk_data): diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 58b688634d..9e36f48138 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -42,7 +42,7 @@ def compute_loss_linen(intermediate_outputs, logits, data, config, model, params """ A loss function wrapper that deals with both vocab tiling or non-vocab tiling cases """ - if config.num_vocab_tiling > 1: + if config.num_batch_seq_tiling > 1: hidden_state_key = ("intermediates", "decoder", "hidden_states") hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] total_loss, _ = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train) @@ -206,7 +206,7 @@ def test_vocab_tiling_gradient_with_z_loss(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, z_loss_multiplier=1e-4, # Enable z-loss ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) @@ -242,7 +242,7 @@ def test_vocab_tiling_gradient_with_z_loss(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, z_loss_multiplier=1e-4, # Enable z-loss ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) @@ -273,7 +273,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -308,7 
+308,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -337,7 +337,7 @@ def test_vocab_tiling_gradient_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) @@ -372,7 +372,7 @@ def test_vocab_tiling_gradient_tied_embedding(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) @@ -399,7 +399,7 @@ def test_vocab_tiling_gradient_data_parallelism(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -435,7 +435,7 @@ def test_vocab_tiling_gradient_data_parallelism(self): dtype="float32", matmul_precision="high", ici_data_parallelism=4, - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -463,7 +463,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): base_num_decoder_layers=0, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -499,7 +499,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): dtype="float32", matmul_precision="high", ici_tensor_parallelism=4, - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = 
self.get_grads(cfg_tiling, params, data) # Loss correctness test @@ -529,7 +529,7 @@ def test_vocab_tiling_gradient_context_parallelism(self): packing=False, dtype="float32", matmul_precision="high", - num_vocab_tiling=1, + num_batch_seq_tiling=1, ) quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -567,7 +567,7 @@ def test_vocab_tiling_gradient_context_parallelism(self): packing=False, dtype="float32", matmul_precision="high", - num_vocab_tiling=4, + num_batch_seq_tiling=4, ) loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)