Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "device",
Expand All @@ -1739,7 +1739,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand All @@ -1760,7 +1762,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"remat_policy": "custom",
Expand All @@ -1779,7 +1781,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand All @@ -1800,7 +1804,7 @@
model_type="gemma3-12b",
tuning_params={
"per_device_batch_size": 1,
"num_vocab_tiling": 16,
"num_batch_seq_tiling": 16,
"ici_fsdp_parallelism": 1,
"ici_fsdp_transpose_parallelism": -1,
"remat_policy": "custom",
Expand All @@ -1819,7 +1823,9 @@
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 2,
"tokenizer_path": os.path.join("assets", "tokenizers", "tokenizer.gemma3"),
"tokenizer_path": os.path.join(
"assets", "tokenizers", "tokenizer.gemma3"
),
"sa_block_q": 1024,
"sa_block_kv": 1024,
"sa_block_kv_compute": 1024,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/maxtext_v5p_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"remat_policy": "custom",
"context": "offload",
"mlpwo": "offload",
"num_vocab_tiling": 4,
"num_batch_seq_tiling": 4,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
Expand Down
4 changes: 2 additions & 2 deletions docs/reference/core_concepts/tiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ The final output unembedding layer of a language model maps hidden states to log

Vocabulary tiling avoids materializing the full logits tensor. Instead, it tiles the input hidden states and computes the logits, loss, and gradients one tile at a time. Unlike GA, which is applied at the start of the model, vocabulary tiling is applied only to the input of the final layer.

In MaxText, the `num_vocab_tiling` configuration controls the number of tiles. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. By preventing out-of-memory errors, vocabulary tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance.
In MaxText, the `num_batch_seq_tiling` configuration controls the number of tiles along the batch and sequence axes. This technique is especially advantageous for models with large vocabularies (e.g., Gemma and Llama), particularly when training with long sequence lengths. The vocabulary dimension can also be tiled using the `num_vocab_tiling` configuration. By preventing out-of-memory errors, this tiling can enable simpler sharding strategies (like FSDP) and unlock better computational performance.

![Illustration of vocabulary tiling.](../../_static/vocab_tiling.png)
![Illustration of batch_sequence tiling.](../../_static/vocab_tiling.png)
*Figure 2: Batch-sequence tiling processes hidden states in tiles to avoid generating the full logits tensor.*

### Other Tiling Methods
Expand Down
3 changes: 2 additions & 1 deletion docs/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ MaxText is [available in PyPI](https://pypi.org/project/maxtext/) and can be ins
- [Optimized models tiering documentation](https://maxtext.readthedocs.io/en/latest/reference/models/tiering.html) has been refreshed.
- Added Versioning. Check out our [first set of release notes](https://maxtext.readthedocs.io/en/latest/release_notes.html)!
- Post-Training (SFT, RL) via [Tunix](https://github.com/google/tunix) is now available.
- Vocabulary tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_vocab_tiling` to unlock more efficient memory usage.
- Batch-Sequence tiling ([PR](https://github.com/AI-Hypercomputer/maxtext/pull/2242)) is now supported in MaxText! Adjust config `num_batch_seq_tiling` to unlock more efficient memory usage.
- Additionally, the vocabulary dimension can also be tiled by adjusting `num_vocab_tiling`.
- The GPT-OSS family of models (20B, 120B) is now supported.

# Deprecations
Expand Down
8 changes: 5 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,12 @@ num_slices: -1

# Vocab and Batch-Sequence Tiling Configs
# Enables a memory-saving optimization by computing the cross-entropy loss in chunks.
# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis,
# reducing peak memory usage. This is highly recommended for models with large
# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable.
# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis,
# and `num_batch_seq_tiling` parts along the batch-sequence axis, reducing peak memory usage.
# This is highly recommended for models with large vocabularies (e.g., Gemma).
# Set to a value greater than 1 to enable.
num_vocab_tiling: 1
num_batch_seq_tiling: 1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
Expand Down
41 changes: 37 additions & 4 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,34 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
)


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
def validate_batch_seq_tiling(
    num_batch_seq_tiling: int,
    per_device_batch_size: int,
    max_target_length: int,
    enable_nnx: bool,
):
  """Check that the batch-sequence tiling config is usable.

  Args:
    num_batch_seq_tiling: Number of tiles along the flattened batch*sequence axis.
    per_device_batch_size: Batch size per device.
    max_target_length: Sequence length of each example.
    enable_nnx: Whether the model runs on the NNX module path.

  Raises:
    ValueError: If the token count per device is not evenly divisible by the
      tile count, or if tiling is requested together with NNX.
  """
  tokens_per_device = per_device_batch_size * max_target_length
  if tokens_per_device % num_batch_seq_tiling != 0:
    raise ValueError(
        "Per device batch size times sequence length should be divisible by the"
        " number of batch seq tiles."
    )
  # TODO (chengnuojin) enable batch-seq tiling on NNX after NNX migration
  if enable_nnx and num_batch_seq_tiling > 1:
    raise ValueError(
        "We currently don't support batch seq tiling on NNX module."
    )


def validate_vocab_tiling(
    num_vocab_tiling: int,
    vocab_size: int,
    enable_nnx: bool,
):
  """Check that the vocabulary tiling config is usable.

  Args:
    num_vocab_tiling: Number of tiles along the vocabulary axis.
    vocab_size: Size of the model vocabulary.
    enable_nnx: Whether the model runs on the NNX module path.

  Raises:
    ValueError: If the vocabulary does not split evenly into the requested
      number of tiles, or if tiling is requested together with NNX.
  """
  if vocab_size % num_vocab_tiling:
    raise ValueError(
        "vocab_size should be divisible by the number of vocab tiles."
    )
  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
  if enable_nnx and num_vocab_tiling > 1:
    raise ValueError("We currently don't support vocab tiling on NNX module.")

Expand Down Expand Up @@ -240,8 +265,16 @@ def validate_keys(keys):
validate_model_call_mode(keys["model_call_mode"])
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
validate_rope_type(keys["rope_type"])
validate_batch_seq_tiling(
keys["num_batch_seq_tiling"],
keys["per_device_batch_size"],
keys["max_target_length"],
keys["enable_nnx"],
)
validate_vocab_tiling(
keys["num_vocab_tiling"], keys["per_device_batch_size"], keys["max_target_length"], keys["enable_nnx"]
keys["num_vocab_tiling"],
keys["vocab_size"],
keys["enable_nnx"],
)
if keys["enable_rampup_batch_size"]:
validate_rampup_batch_size(
Expand Down
29 changes: 25 additions & 4 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,17 @@ class Tokenizer(BaseModel):
)
num_vocab_tiling: int = Field(
1,
description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the vocabulary axis. >1 to enable."
),
)
num_batch_seq_tiling: int = Field(
1,
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the batch-sequence axis. >1 to enable."
),
)


Expand Down Expand Up @@ -2503,12 +2513,23 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.quantization:
raise ValueError("Quantization is not supported with 'explicit' sharding.")
if self.vocab_size % self.num_vocab_tiling != 0:
raise ValueError(
"vocab_size should be divisible by the number of vocab tiles."
)
if (
self.per_device_batch_size > 0
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
and (self.per_device_batch_size * self.max_target_length)
% self.num_batch_seq_tiling
!= 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError(
"Per device batch size times sequence length should be divisible by"
" the number of batch tiles."
)
if (
self.num_vocab_tiling > 1 or self.num_batch_seq_tiling > 1
) and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
if "gpu" not in self.hardware:
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,9 +1085,9 @@ def __call__(
# for efficiency, as the main model is frozen and the LM loss is not needed.
elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
logits = None
# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
# When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
elif cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow("intermediates", "hidden_states", hidden_state)

Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,9 +1057,9 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in):
if cfg.attention == "vllm_rpa":
logits = None

# When vocab tiling is enabled in training mode, full logits won't generate to reduce memory
# When batch-sequence tiling is enabled in training mode, full logits won't generate to reduce memory
# Instead, we keep track on the hidden states, which has smaller size compared to full logits
if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
if cfg.num_batch_seq_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
logits = None
self.sow(nnx.Intermediate, "hidden_states", hidden_state)

Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,8 +531,8 @@ def __call__(
mutable=mutable_collections,
) # pytype: disable=wrong-keyword-args

# Materialize hidden state when vocab tiling is enabled
if self.config.num_vocab_tiling > 1:
# Materialize hidden state when batch-sequence tiling is enabled.
if self.config.num_batch_seq_tiling > 1:
self.hidden_states = hidden_state

# If we are initializing the model AND MTP is enabled, we must create
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):
# The main model parameters are frozen and only the indexer is trained via KL divergence.
total_loss = 0.0
total_z_loss = 0.0
elif config.num_vocab_tiling > 1:
elif config.num_batch_seq_tiling > 1:
hidden_state_key = ("intermediates", "decoder", "hidden_states")
hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
total_loss, total_z_loss = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
Expand Down
18 changes: 14 additions & 4 deletions src/maxtext/utils/vocabulary_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,23 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat

def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation):
batch_size, seq_len, emb_dim = hidden_states.shape
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
batch_seq_tile_size = (batch_size * seq_len) // config.num_batch_seq_tiling

reshaped_hidden_states = _reshape(
hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec
hidden_states,
(config.num_batch_seq_tiling, batch_seq_tile_size, emb_dim),
reshaped_hidden_spec,
)
reshaped_labels = _reshape(
labels,
(config.num_batch_seq_tiling, batch_seq_tile_size),
reshaped_data_spec,
)
reshaped_segmentation = _reshape(
segmentation,
(config.num_batch_seq_tiling, batch_seq_tile_size),
reshaped_data_spec,
)
reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)
reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec)

# Scan body accumulates loss from each tile given chunked hidden states and labels
def _fwd_scan_body(accumulators, chunk_data):
Expand Down
26 changes: 13 additions & 13 deletions tests/unit/tiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def compute_loss_linen(intermediate_outputs, logits, data, config, model, params
"""
A loss function wrapper that deals with both vocab tiling or non-vocab tiling cases
"""
if config.num_vocab_tiling > 1:
if config.num_batch_seq_tiling > 1:
hidden_state_key = ("intermediates", "decoder", "hidden_states")
hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0]
total_loss, _ = vocab_tiling_linen_loss(hidden_states, data, config, model, params, is_train)
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_vocab_tiling_gradient_with_z_loss(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
z_loss_multiplier=1e-4, # Enable z-loss
)
quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_vocab_tiling_gradient_with_z_loss(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=4,
num_batch_seq_tiling=4,
z_loss_multiplier=1e-4, # Enable z-loss
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)
Expand Down Expand Up @@ -273,7 +273,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
)
quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling)
Expand Down Expand Up @@ -308,7 +308,7 @@ def test_vocab_tiling_gradient_non_tied_embedding(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=4,
num_batch_seq_tiling=4,
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)
# Loss correctness test
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_vocab_tiling_gradient_tied_embedding(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
)

quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
Expand Down Expand Up @@ -372,7 +372,7 @@ def test_vocab_tiling_gradient_tied_embedding(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=4,
num_batch_seq_tiling=4,
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)

Expand All @@ -399,7 +399,7 @@ def test_vocab_tiling_gradient_data_parallelism(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
)
quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling)
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_vocab_tiling_gradient_data_parallelism(self):
dtype="float32",
matmul_precision="high",
ici_data_parallelism=4,
num_vocab_tiling=4,
num_batch_seq_tiling=4,
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)
# Loss correctness test
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self):
base_num_decoder_layers=0,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
)
quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling)
Expand Down Expand Up @@ -499,7 +499,7 @@ def test_vocab_tiling_gradient_tensor_parallelism(self):
dtype="float32",
matmul_precision="high",
ici_tensor_parallelism=4,
num_vocab_tiling=4,
num_batch_seq_tiling=4,
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)
# Loss correctness test
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_vocab_tiling_gradient_context_parallelism(self):
packing=False,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=1,
num_batch_seq_tiling=1,
)
quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling)
devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling)
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_vocab_tiling_gradient_context_parallelism(self):
packing=False,
dtype="float32",
matmul_precision="high",
num_vocab_tiling=4,
num_batch_seq_tiling=4,
)
loss_tiling, grads_tiling = self.get_grads(cfg_tiling, params, data)

Expand Down
Loading