Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import pprint
from functools import partial

from datasets import concatenate_datasets
from omegaconf import OmegaConf
from transformers import AutoTokenizer

Expand All @@ -29,6 +28,7 @@
load_response_dataset,
update_single_dataset_config,
)
from nemo_rl.data.utils import merge_datasets
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.utils.config import (
load_config,
Expand Down Expand Up @@ -89,7 +89,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
if hasattr(data, "preprocessor") and data.preprocessor is not None:
task_data_preprocessors[data.task_name] = data.preprocessor

merged_data = concatenate_datasets([data.dataset for data in data_list])
merged_data = merge_datasets([data.dataset for data in data_list])
dataset = AllTaskProcessedDataset(
merged_data,
tokenizer,
Expand Down Expand Up @@ -144,7 +144,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):

val_dataset = None
if len(val_data_list) > 0:
merged_val_data = concatenate_datasets(val_data_list)
merged_val_data = merge_datasets(val_data_list)
val_dataset = AllTaskProcessedDataset(
merged_val_data,
tokenizer,
Expand Down
22 changes: 19 additions & 3 deletions nemo_rl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import Any, Optional, Union

from datasets import concatenate_datasets
from datasets import Dataset, concatenate_datasets
from transformers import AutoProcessor, AutoTokenizer

from nemo_rl.data import DataConfig
Expand All @@ -25,11 +25,27 @@
load_response_dataset,
update_single_dataset_config,
)
from nemo_rl.data.datasets.response_datasets.oai_format_dataset import (
PreservingDataset,
)
from nemo_rl.data.processors import preference_preprocessor
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.environments.utils import create_env


def merge_datasets(datasets: list) -> Union[Dataset, "PreservingDataset"]:
    """Merge a list of datasets, handling both HuggingFace Dataset and PreservingDataset.

    HuggingFace's ``concatenate_datasets`` does not accept ``PreservingDataset``
    objects, and routing ``PreservingDataset`` items through HF would re-introduce
    schema unification (None-filling of keys missing from some items). This helper
    detects the dataset types and merges them appropriately.

    Args:
        datasets: Non-empty list of datasets to merge. All entries must be of the
            same kind — either all ``PreservingDataset`` or all HF-compatible
            ``Dataset`` objects.

    Returns:
        A single merged dataset of the same kind as the inputs.

    Raises:
        ValueError: If ``datasets`` is empty. (Previously an empty list silently
            produced an empty ``PreservingDataset`` because ``all()`` over an
            empty iterable is vacuously True — the wrong type if the caller was
            working with HF Datasets.)
        TypeError: If the list mixes ``PreservingDataset`` with other dataset
            types; such a list would otherwise reach ``concatenate_datasets``
            and fail with an unhelpful error.
    """
    if not datasets:
        raise ValueError("merge_datasets() requires at least one dataset")

    if all(isinstance(d, PreservingDataset) for d in datasets):
        # Flatten the underlying item lists directly, bypassing HF entirely so
        # heterogeneous per-item keys are preserved (no None-filling).
        merged_data = [item for d in datasets for item in d.data]
        return PreservingDataset(merged_data)

    if any(isinstance(d, PreservingDataset) for d in datasets):
        # Fail fast with an actionable message instead of an obscure HF error.
        raise TypeError(
            "Cannot merge a mix of PreservingDataset and other dataset types; "
            "all datasets must be of the same kind."
        )

    return concatenate_datasets(datasets)


# TODO: @yukih: unify to setup_data after dataset refactored
def setup_response_data(
tokenizer: AutoProcessor | AutoTokenizer,
Expand Down Expand Up @@ -134,7 +150,7 @@ def setup_response_data(
}
else:
# merge datasets into a single dataset
merged_data = concatenate_datasets([data.dataset for data in data_list])
merged_data = merge_datasets([data.dataset for data in data_list])
dataset = AllTaskProcessedDataset(
merged_data,
tokenizer,
Expand Down Expand Up @@ -199,7 +215,7 @@ def setup_response_data(
# merge datasets
val_dataset = None
if len(val_data_list) > 0:
merged_val_data = concatenate_datasets(val_data_list)
merged_val_data = merge_datasets(val_data_list)
val_dataset = AllTaskProcessedDataset(
merged_val_data,
tokenizer,
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/data/datasets/test_preserving_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,59 @@ def test_comparison_with_standard_dataset(self):
preserving_dataset = PreservingDataset(data)
assert preserving_dataset[0]["tool_id"] == "123"
assert "tool_id" not in preserving_dataset[1] # Key doesn't exist


class TestMergeDatasets:
    """Tests for the merge_datasets helper across HF Dataset and PreservingDataset inputs."""

    def test_merge_preserving_datasets(self):
        """Concatenating several PreservingDatasets yields one PreservingDataset, in order."""
        from nemo_rl.data.utils import merge_datasets

        first = PreservingDataset([{"a": 1}, {"b": 2}])
        second = PreservingDataset([{"c": 3}])

        result = merge_datasets([first, second])

        assert isinstance(result, PreservingDataset)
        assert len(result) == 3
        expected_items = [{"a": 1}, {"b": 2}, {"c": 3}]
        for idx, item in enumerate(expected_items):
            assert result[idx] == item

    def test_merge_hf_datasets(self):
        """Plain HuggingFace Datasets are still merged via concatenate_datasets."""
        from nemo_rl.data.utils import merge_datasets

        parts = [
            Dataset.from_list([{"x": 1}, {"x": 2}]),
            Dataset.from_list([{"x": 3}]),
        ]

        result = merge_datasets(parts)

        assert isinstance(result, Dataset)
        assert len(result) == 3
        assert result[0]["x"] == 1
        assert result[2]["x"] == 3

    def test_merge_single_preserving_dataset(self):
        """A one-element list round-trips to an equivalent PreservingDataset."""
        from nemo_rl.data.utils import merge_datasets

        only = PreservingDataset([{"a": 1, "b": 2}, {"c": 3}])

        result = merge_datasets([only])

        assert isinstance(result, PreservingDataset)
        assert len(result) == 2

    def test_merge_preserving_datasets_preserves_heterogeneous_structure(self):
        """Merging must not unify schemas: absent keys stay absent (no None-filling)."""
        from nemo_rl.data.utils import merge_datasets

        with_tool = PreservingDataset(
            [{"role": "user", "content": "hi", "tool_id": "1"}]
        )
        without_tool = PreservingDataset([{"role": "assistant", "content": "hello"}])

        result = merge_datasets([with_tool, without_tool])

        assert "tool_id" in result[0]
        assert "tool_id" not in result[1]  # key absent in input stays absent
Loading