Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/maxtext/trainers/tokenizer/train_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import jax
import grain.python as grain

from maxtext.input_pipeline import input_pipeline_utils
from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
from maxtext.utils import gcs_utils

Expand All @@ -50,7 +51,7 @@
"grain_train_files", None, "File pattern for training data (local or gs://)", required=True
)
_GRAIN_FILE_TYPE = flags.DEFINE_string(
"grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord'."
"grain_file_type", "parquet", "Type of data files. Supported: 'parquet', 'arrayrecord', 'tfrecord'."
)
_DATA_COLUMN = flags.DEFINE_string("data_column", "text", "Column name to extract text from (used for arrayrecord).")
_VOCAB_SIZE = flags.DEFINE_integer("vocab_size", 32_768, "Vocab size")
Expand Down Expand Up @@ -82,27 +83,23 @@ def build_grain_iterator(data_file_pattern: str, data_file_type: str, data_keys:
dataset = grain.MapDataset.source(data_files)
dataset = dataset.map(grain.experimental.ParquetIterDataset)
dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=list(data_keys)))
return iter(dataset)
elif data_file_type == "arrayrecord":
from maxtext.input_pipeline.protos import example_pb2 # pylint: disable=import-outside-toplevel

source = grain.ArrayRecordDataSource(data_files)
dataset = grain.MapDataset.source(source)

def _parse_example(raw_bytes):
example = example_pb2.Example()
example.ParseFromString(raw_bytes)
features = example.features.feature
parsed = {}
for col in data_keys:
if col in features:
parsed[col] = features[col].bytes_list.value[0]
return parsed

dataset = dataset.map(_parse_example)
dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True))
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True))
return iter(dataset)
elif data_file_type == "tfrecord":
dataset = grain.MapDataset.source(data_files)
dataset = dataset.map(input_pipeline_utils.make_tfrecord_iter_dataset)
dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(data_files))
dataset = dataset.map(input_pipeline_utils.ParseFeatures(list(data_keys), tokenize=True))
dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(list(data_keys), tokenize=True))
return iter(dataset)
else:
raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet' or 'arrayrecord'.")
raise ValueError(f"Unsupported grain_file_type: {data_file_type!r}. Use 'parquet', 'arrayrecord', or 'tfrecord'.")


def _dump_chars_to_textfile(dataset_iter: Iterator, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]:
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/smoke/train_tokenizer_smoke_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoke tests for train_tokenizer file format support."""

import os
import unittest
import pytest

from maxtext.input_pipeline import input_pipeline_utils
from maxtext.trainers.tokenizer import train_tokenizer


class TrainTokenizerFormatTest(unittest.TestCase):
  """Smoke-tests that train_tokenizer runs end-to-end for each supported file format."""

  def _run_format_test(self, file_pattern, file_type):
    """Train a tiny tokenizer from `file_pattern` and verify it round-trips encode/decode.

    The corpus is capped at a few thousand characters and the resulting vocab file
    is deleted afterwards — nothing is persisted by this test.
    """
    vocab_file = os.path.join("tests", f"test_tokenizer_{file_type}")
    try:
      # Build the data iterator for the requested format and train a small vocab from it.
      corpus_iter = train_tokenizer.build_grain_iterator(file_pattern, file_type)
      train_tokenizer.train_tokenizer(
          corpus_iter,
          vocab_path=vocab_file,
          vocab_size=512,
          max_corpus_chars=10_000,
      )
      # Load the freshly trained tokenizer and check a simple round-trip.
      tokenizer = input_pipeline_utils.get_tokenizer(vocab_file, "sentencepiece", add_bos=False, add_eos=False)
      sample = "This is a test"
      token_ids = tokenizer.encode(sample)
      self.assertGreater(len(token_ids), 0)
      self.assertEqual(tokenizer.decode(token_ids), sample)
    finally:
      # Best-effort cleanup of the temporary vocab artifact.
      if os.path.exists(vocab_file):
        os.remove(vocab_file)

  @pytest.mark.cpu_only
  def test_parquet(self):
    self._run_format_test("gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", "parquet")

  @pytest.mark.cpu_only
  def test_arrayrecord(self):
    self._run_format_test(
        "gs://maxtext-dataset/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-01024", "arrayrecord"
    )

  @pytest.mark.cpu_only
  def test_tfrecord(self):
    self._run_format_test("gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord")


# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
  unittest.main()
Loading