Skip to content
Open
136 changes: 111 additions & 25 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ def _auto_preprocess_sample(
ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
or if no recognized column is found.
"""
# Truthy ``sample.get`` checks instead of ``key in sample``: HF's schema
# unification fills missing values with ``None`` across heterogeneous JSONL
# rows, so a row that only has ``text`` would still expose ``prompt=None``
# in the unified schema. Falling through on null/empty lets such rows
# match the next column (e.g. ``text``) instead of crashing on
# ``"\n".join([None])``.
chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None)
if chat_key is not None:
if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
Expand All @@ -191,15 +197,15 @@ def _auto_preprocess_sample(
kwargs["tools"] = tools
return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs)

if "prompt" in sample:
if sample.get("prompt"):
parts = [sample["prompt"]]
parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k))
return "\n".join(parts)

if "text" in sample:
if sample.get("text"):
return sample["text"]

if "input" in sample:
if sample.get("input"):
parts = [sample["input"]]
if sample.get("output"):
parts.append(sample["output"])
Expand Down Expand Up @@ -231,6 +237,15 @@ def get_dataset_samples(
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json``
builder and routed through the same auto-preprocess path as unregistered HF
datasets so chat / prompt / text columns are handled consistently with live
HF datasets. If that path fails on JSON parsing or PyArrow schema
unification, it falls back to a line-by-line reader that extracts the
legacy ``text`` field for backward compatibility. The fallback is also
used when the optional ``datasets`` package isn't installed, preserving
legacy plain-``.jsonl`` workflows in base installations. Local JSONL
files only expose the ``train`` split; passing any other ``split`` raises.
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
Expand All @@ -245,18 +260,43 @@ def get_dataset_samples(
Returns:
Samples: The list of samples.
"""
# Local JSONL file path support (each line is a JSON object with a `text` field).
if dataset_name.endswith(".jsonl"):
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
# Local JSONL: load via HF's ``json`` builder and route through the same
# auto-preprocess path as unregistered HF datasets so chat / prompt / text
# columns are handled consistently with a downloaded HF dataset. Never
# matches ``SUPPORTED_DATASET_CONFIG``.
is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name)

# HF's file-based builders only expose ``train`` for the ``data_files`` form
# we use, so any other split is a caller error. Surface it up front rather
# than letting ``load_dataset`` fail and silently dropping into the
# text-field fallback (which would ignore the requested split).
if is_jsonl and split is not None:
invalid = [s for s in _normalize_splits(split) if s != "train"]
if invalid:
raise ValueError(
f"Local JSONL files only expose the 'train' split, got {invalid}. "
"Either omit ``split`` or pass ``split='train'``."
)

from datasets import load_dataset
# Lazy ``datasets`` import: legacy ``.jsonl`` workflows historically didn't
# require the optional ``datasets`` extra, so keep them working with just
# the stdlib reader when the package isn't installed.
try:
from datasets import load_dataset
except ImportError:
if is_jsonl:
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
raise

local_dataset_path = None
if os.path.exists(dataset_name): # Local path
local_dataset_path = dataset_name
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
if not is_jsonl:
# Directory paths may match a registered key via their basename
# (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail).
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
Expand Down Expand Up @@ -292,29 +332,73 @@ def _preprocess(sample: dict) -> str:
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": local_dataset_path or dataset_name}
if is_jsonl:
config = {"path": "json", "data_files": local_dataset_path}
else:
config = {"path": local_dataset_path or dataset_name}
# HF's file-based builders (incl. ``json``) label a string/list ``data_files``
# as the ``train`` split unconditionally — the filename on disk is ignored.
# Named splits require a dict ``data_files={"train": ..., "test": ...}``,
# which we don't expose here.
splits = _normalize_splits(split) if split is not None else ["train"]

def _preprocess(sample: dict) -> str:
    """Adapt the shared auto-detect preprocessor to this call's closure state.

    # NOTE(review): closure over the enclosing scope's ``dataset_name`` and
    # ``tokenizer``; ``_auto_preprocess_sample`` is presumed to return "" /
    # falsy for unusable rows (callers filter on truthiness) — confirm against
    # its definition, which is outside this diff hunk.
    """
    return _auto_preprocess_sample(sample, dataset_name, tokenizer)

# Narrow the legacy fallback to JSON-parsing / Arrow schema failures. Any
# other error (split-not-found, IO, OOM, ...) should surface to the caller
# rather than be hidden by the text-field reader. Imported lazily because
# the exact module paths vary across versions; an empty tuple is a valid
# ``except`` target that catches nothing if neither is importable.
fallback_types: tuple[type[BaseException], ...] = ()
try:
from datasets.exceptions import DatasetGenerationError

fallback_types += (DatasetGenerationError,)
except ImportError:
pass
try:
from pyarrow.lib import ArrowInvalid

fallback_types += (ArrowInvalid,)
except ImportError:
pass

# load_dataset does not support a list of splits while streaming, so load each separately.
print(f"Loading dataset with {config=} and {splits=}")
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)
try:
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

return samples
samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
except fallback_types as e:
# Backward-compat fallback: legacy callers passed JSONL files whose only usable
# field is ``text``. If the HF ``json`` builder fails on schema inference or
# JSON parsing, fall back to a line-by-line reader that pulls ``text`` directly.
if not is_jsonl:
raise
assert local_dataset_path is not None # is_jsonl implies the path exists
try:
fallback_samples = get_jsonl_text_samples(local_dataset_path, num_samples, key="text")
except Exception:
# Fallback can't help either — surface the original HF error.
raise e from None
warn(
f"Failed to load {local_dataset_path} via the HF 'json' builder "
f"({type(e).__name__}: {e}); fell back to legacy text-field reader."
)
return fallback_samples


class _CustomDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -345,8 +429,10 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list
mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the
resulting samples are concatenated before tokenization. ``num_samples`` may be
an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
Expand Down
Loading
Loading