examples/specdec_bench/specdec_bench/datasets/speed.py (37 additions, 2 deletions)
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 # mypy: disable-error-code="index"
+import math
 import random
 import re
 from enum import Enum
@@ -737,10 +738,44 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Dataset":
         }
         table = table.replace_schema_metadata(new_meta or None)
         dataset = HFDataset(table)
-        if self.num_samples is not None:
-            dataset = dataset.select(range(self.num_samples))
+        if self.num_samples is not None and self.num_samples < len(dataset):
+            dataset = self._stratified_select(dataset, self.num_samples)
         return dataset
 
+    @staticmethod
+    def _stratified_select(dataset: "Dataset", n: int) -> "Dataset":
+        """Select ``n`` samples uniformly across the ``category`` column.
+
+        When ``category`` is present, each category contributes
+        ``ceil(n / num_categories)`` rows (capped by category size); the
+        result is truncated to exactly ``n`` rows by interleaving the
+        per-category samples round-robin so any further prefix slice
+        remains balanced. Falls back to ``range(n)`` when ``category`` is
+        absent or only one category exists. Rows are taken in dataset
+        order (the first rows of each category), not randomly, so
+        behavior is deterministic.
+        """
+        if "category" not in dataset.column_names:
+            return dataset.select(range(n))
+        cat_to_rows: dict[str, list[int]] = {}
+        for i, c in enumerate(dataset["category"]):
+            cat_to_rows.setdefault(c, []).append(i)
+        num_cats = len(cat_to_rows)
+        if num_cats <= 1:
+            return dataset.select(range(n))
+        per_cat = math.ceil(n / num_cats)
+        # Take the first ``per_cat`` rows from each category (parquet order
+        # within a category is treated as the canonical sample order).
+        cat_samples = [
+            rows[: min(per_cat, len(rows))] for rows in cat_to_rows.values()
+        ]
+        # Round-robin interleave so the first N rows are balanced.
+        interleaved: list[int] = []
+        for i in range(per_cat):
+            for samples in cat_samples:
+                if i < len(samples):
+                    interleaved.append(samples[i])
+        return dataset.select(interleaved[:n])
+
     def _resolve_external_data(
         self, dataset: "Dataset", speed_config: config_type | str
     ) -> "Dataset":
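Note: a minimal standalone sketch of how the selection added above behaves on toy data. The inlined logic mirrors `_stratified_select` for illustration only; the toy dataset and variable names are hypothetical, and the `datasets` library is assumed to be installed.

```python
import math

from datasets import Dataset

# Hypothetical toy data: three categories of unequal size, in file order.
toy = Dataset.from_dict(
    {
        "prompt": [f"p{i}" for i in range(9)],
        "category": ["a", "a", "a", "a", "b", "b", "b", "c", "c"],
    }
)

# Inlined mirror of _stratified_select(toy, n=5).
n = 5
cat_to_rows: dict[str, list[int]] = {}
for i, c in enumerate(toy["category"]):
    cat_to_rows.setdefault(c, []).append(i)
per_cat = math.ceil(n / len(cat_to_rows))  # ceil(5 / 3) == 2
cat_samples = [rows[:per_cat] for rows in cat_to_rows.values()]
# Round-robin interleave: [0, 4, 7, 1, 5, 8], then truncate to n rows.
interleaved = [
    samples[i]
    for i in range(per_cat)
    for samples in cat_samples
    if i < len(samples)
]
print(toy.select(interleaved[:n])["category"])
# -> ['a', 'b', 'c', 'a', 'b']: any prefix of the selection stays balanced.
```

Because the interleave visits categories round-robin, slicing the result to any smaller prefix (e.g. the first 3 rows) still yields at most one more row from any category than from any other.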