Skip to content
Open
85 changes: 63 additions & 22 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ def get_dataset_samples(
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json``
builder and routed through the same auto-preprocess path as unregistered HF
datasets so chat / prompt / text columns are handled consistently with live
HF datasets. If that path fails (e.g. PyArrow schema unification across
heterogeneous rows), it falls back to a line-by-line reader that extracts
the legacy ``text`` field for backward compatibility.
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
Expand All @@ -245,18 +251,23 @@ def get_dataset_samples(
Returns:
Samples: The list of samples.
"""
# Local JSONL file path support (each line is a JSON object with a `text` field).
if dataset_name.endswith(".jsonl"):
return get_jsonl_text_samples(dataset_name, num_samples, key="text")

from datasets import load_dataset

# Local JSONL: load via HF's ``json`` builder and route through the same
# auto-preprocess path as unregistered HF datasets so chat / prompt / text
# columns are handled consistently with a downloaded HF dataset. Never
# matches ``SUPPORTED_DATASET_CONFIG``.
is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name)

local_dataset_path = None
if os.path.exists(dataset_name): # Local path
local_dataset_path = dataset_name
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
if not is_jsonl:
# Directory paths may match a registered key via their basename
# (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail).
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
Expand Down Expand Up @@ -292,29 +303,57 @@ def _preprocess(sample: dict) -> str:
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": local_dataset_path or dataset_name}
if is_jsonl:
config = {"path": "json", "data_files": local_dataset_path}
else:
config = {"path": local_dataset_path or dataset_name}
# HF's file-based builders (incl. ``json``) label a string/list ``data_files``
# as the ``train`` split unconditionally — the filename on disk is ignored.
# Named splits require a dict ``data_files={"train": ..., "test": ...}``,
# which we don't expose here.
splits = _normalize_splits(split) if split is not None else ["train"]

def _preprocess(sample: dict) -> str:
return _auto_preprocess_sample(sample, dataset_name, tokenizer)

# load_dataset does not support a list of splits while streaming, so load each separately.
print(f"Loading dataset with {config=} and {splits=}")
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)
try:
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)

return samples
samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
except Exception as e:
# Backward-compat fallback: legacy callers passed JSONL files whose only usable
# field is ``text``. If the HF ``json`` builder or auto-detect can't handle the
# file (schema inference error, unrecognized columns, etc.), fall back to a
# line-by-line reader that pulls the ``text`` field directly.
if is_jsonl:
assert local_dataset_path is not None # is_jsonl implies the path exists
try:
fallback_samples = get_jsonl_text_samples(
local_dataset_path, num_samples, key="text"
)
except Exception:
# Fallback can't help either — surface the original HF error.
raise e from None
warn(
f"Failed to load {local_dataset_path} via the HF 'json' builder "
f"({type(e).__name__}: {e}); fell back to legacy text-field reader."
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
return fallback_samples
raise
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated


class _CustomDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -345,8 +384,10 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list
mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the
resulting samples are concatenated before tokenization. ``num_samples`` may be
an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
Expand Down
Loading
Loading