diff --git a/examples/specdec_bench/specdec_bench/datasets/speed.py b/examples/specdec_bench/specdec_bench/datasets/speed.py index fe544bb353e..3552d71a1ad 100644 --- a/examples/specdec_bench/specdec_bench/datasets/speed.py +++ b/examples/specdec_bench/specdec_bench/datasets/speed.py @@ -737,10 +737,40 @@ def _load_dataset(self, config_name_or_dataset_path: config_type | str) -> "Data } table = table.replace_schema_metadata(new_meta or None) dataset = HFDataset(table) - if self.num_samples is not None: - dataset = dataset.select(range(self.num_samples)) + if self.num_samples is not None and self.num_samples < len(dataset): + dataset = self._stratified_select(dataset, self.num_samples) return dataset + @staticmethod + def _stratified_select(dataset: "Dataset", n: int) -> "Dataset": + """Select ``n`` samples uniformly across the ``category`` column. + + Round-robin across categories until ``n`` rows are collected. The + resulting prefix is balanced; once a smaller category is exhausted + the remaining categories continue contributing, so exactly ``n`` + rows are returned whenever ``n`` does not exceed the dataset size. + Falls back to ``range(n)`` when ``category`` is absent or there is + only one category. Indices come from ``range(category_size)`` (not + random) so behavior is deterministic. + """ + if "category" not in dataset.column_names: + return dataset.select(range(n)) + cat_to_rows: dict[str, list[int]] = {} + for i, c in enumerate(dataset["category"]): + cat_to_rows.setdefault(c, []).append(i) + if len(cat_to_rows) <= 1: + return dataset.select(range(n)) + cat_lists = list(cat_to_rows.values()) + interleaved: list[int] = [] + max_len = max(len(c) for c in cat_lists) + for i in range(max_len): + for c in cat_lists: + if i < len(c): + interleaved.append(c[i]) + if len(interleaved) == n: + return dataset.select(interleaved) + return dataset.select(interleaved) + def _resolve_external_data( self, dataset: "Dataset", speed_config: config_type | str ) -> "Dataset":