Skip to content
Open
155 changes: 124 additions & 31 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,11 @@ def _auto_preprocess_sample(
ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
or if no recognized column is found.
"""
chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None)

def _has_non_null_value(key: str) -> bool:
return sample.get(key) is not None

chat_key = next((k for k in ("messages", "conversations") if _has_non_null_value(k)), None)
if chat_key is not None:
if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
raise ValueError(
Expand All @@ -187,21 +191,23 @@ def _auto_preprocess_sample(
)
kwargs: dict[str, Any] = {}
tools = sample.get("tools")
if tools:
if tools is not None:
kwargs["tools"] = tools
return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs)

if "prompt" in sample:
if _has_non_null_value("prompt"):
parts = [sample["prompt"]]
parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k))
parts.extend(
sample[k] for k in ("completion", "response", "output") if _has_non_null_value(k)
)
return "\n".join(parts)

if "text" in sample:
if _has_non_null_value("text"):
return sample["text"]

if "input" in sample:
if _has_non_null_value("input"):
parts = [sample["input"]]
if sample.get("output"):
if _has_non_null_value("output"):
parts.append(sample["output"])
return "\n".join(parts)

Expand Down Expand Up @@ -231,6 +237,15 @@ def get_dataset_samples(
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json``
builder and routed through the same auto-preprocess path as unregistered HF
datasets so chat / prompt / text columns are handled consistently with live
HF datasets. If that path fails on JSON parsing or PyArrow schema
unification, it falls back to a line-by-line reader that extracts the
legacy ``text`` field for backward compatibility. The fallback is also
used when the optional ``datasets`` package isn't installed, preserving
legacy plain-``.jsonl`` workflows in base installations. Local JSONL
files only expose the ``train`` split; passing any other ``split`` raises.
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
Expand All @@ -245,25 +260,53 @@ def get_dataset_samples(
Returns:
Samples: The list of samples.
"""
# Local JSONL file path support (each line is a JSON object with a `text` field).
if dataset_name.endswith(".jsonl"):
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
# Local JSONL: load via HF's ``json`` builder and route through the same
# auto-preprocess path as unregistered HF datasets so chat / prompt / text
# columns are handled consistently with a downloaded HF dataset. Never
# matches ``SUPPORTED_DATASET_CONFIG``.
is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name)
requested_splits = _normalize_splits(split) if split is not None else None
if requested_splits is not None and not requested_splits:
raise ValueError("``split`` must contain at least one split name.")

# HF's file-based builders only expose ``train`` for the ``data_files`` form
# we use, so any other split is a caller error. Surface it up front rather
# than letting ``load_dataset`` fail and silently dropping into the
# text-field fallback (which would ignore the requested split).
if is_jsonl and requested_splits is not None:
invalid = [s for s in requested_splits if s != "train"]
if invalid:
raise ValueError(
f"Local JSONL files only expose the 'train' split, got {invalid}. "
"Either omit ``split`` or pass ``split='train'``."
)

from datasets import load_dataset
# Lazy ``datasets`` import: legacy ``.jsonl`` workflows historically didn't
# require the optional ``datasets`` extra, so keep them working with just
# the stdlib reader when the package isn't installed.
try:
from datasets import load_dataset
except ImportError:
if is_jsonl:
return get_jsonl_text_samples(dataset_name, num_samples, key="text")
raise

local_dataset_path = None
if os.path.exists(dataset_name): # Local path
local_dataset_path = dataset_name
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
if not is_jsonl:
# Directory paths may match a registered key via their basename
# (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail).
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
config = dataset_config["config"].copy()
if local_dataset_path:
config["path"] = local_dataset_path
splits = _normalize_splits(split) if split is not None else config.pop("split", [None])
splits = requested_splits if requested_splits is not None else config.pop("split", [None])
if split is not None:
config.pop("split", None)

Expand Down Expand Up @@ -292,29 +335,77 @@ def _preprocess(sample: dict) -> str:
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": local_dataset_path or dataset_name}
splits = _normalize_splits(split) if split is not None else ["train"]
if is_jsonl:
config = {"path": "json", "data_files": local_dataset_path}
else:
config = {"path": local_dataset_path or dataset_name}
# HF's file-based builders (incl. ``json``) label a string/list ``data_files``
# as the ``train`` split unconditionally — the filename on disk is ignored.
# Named splits require a dict ``data_files={"train": ..., "test": ...}``,
# which we don't expose here.
splits = requested_splits if requested_splits is not None else ["train"]

def _preprocess(sample: dict) -> str:
return _auto_preprocess_sample(sample, dataset_name, tokenizer)

if not splits:
raise ValueError("``split`` must contain at least one split name.")

# Narrow the legacy fallback to JSON-parsing / Arrow schema failures. Any
# other error (split-not-found, IO, OOM, ...) should surface to the caller
# rather than be hidden by the text-field reader. Imported lazily because
# the exact module paths vary across versions; an empty tuple is a valid
# ``except`` target that catches nothing if neither is importable.
fallback_types: tuple[type[BaseException], ...] = ()
try:
from datasets.exceptions import DatasetGenerationError

fallback_types += (DatasetGenerationError,)
except ImportError:
pass
try:
from pyarrow.lib import ArrowInvalid

fallback_types += (ArrowInvalid,)
except ImportError:
pass

# load_dataset does not support a list of splits while streaming, so load each separately.
print(f"Loading dataset with {config=} and {splits=}")
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
try:
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)
num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
num_per_split[-1] += num_samples - sum(num_per_split)

samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
samples: list[str] = []
for dataset, n in zip(dataset_splits, num_per_split):
for i, sample in enumerate(dataset):
if i >= n:
break
text = _preprocess(sample)
if text:
samples.append(text)

return samples
except fallback_types as e:
# Backward-compat fallback: legacy callers passed JSONL files whose only usable
# field is ``text``. If the HF ``json`` builder fails on schema inference or
# JSON parsing, fall back to a line-by-line reader that pulls ``text`` directly.
if not is_jsonl:
raise
assert local_dataset_path is not None # is_jsonl implies the path exists
try:
fallback_samples = get_jsonl_text_samples(local_dataset_path, num_samples, key="text")
except Exception:
# Fallback can't help either — surface the original HF error.
raise e from None
safe_name = Path(local_dataset_path).name
warn(
f"Failed to load JSONL file '{safe_name}' via the HF 'json' builder "
f"({type(e).__name__}); fell back to legacy text-field reader."
)
return fallback_samples


class _CustomDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -345,8 +436,10 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list
mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the
resulting samples are concatenated before tokenization. ``num_samples`` may be
an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ skips = [
addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=50 --strict-markers"
pythonpath = ["tests/"]
markers = [
"integration: Tests that require external services or other non-hermetic dependencies",
"manual: Only run when --run-manual is given",
"network: Tests that require network access",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are many other tests that require network access. Should this marker be added to all of them as well?

"release: Regression tests that should be run before every release",
]

Expand Down
Loading
Loading