Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions olive/cli/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser):
help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.",
)

lmeval_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=None,
help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).",
)

add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
Expand Down Expand Up @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict:
("evaluators", "evaluator", "model_class"),
None if self.args.backend == "auto" else self.args.backend,
),
(
("evaluators", "evaluator", "confirm_run_unsafe_code"),
True if self.args.confirm_run_unsafe_code else None,
),
]

for keys, value in to_replace:
Expand Down
146 changes: 144 additions & 2 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[fl
raise NotImplementedError("Yet to be implemented!")

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Unsupported for this backend; raise with guidance toward the generative backend.

    Generative tasks (e.g. MBPP, HumanEval) require the 'ortgenai' model class,
    which implements token-by-token generation with stop-sequence handling.
    """
    raise NotImplementedError(
        "generate_until is not supported by this model backend. "
        "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval."
    )


@register_model("ort")
Expand Down Expand Up @@ -509,7 +512,16 @@ def __init__(
self.max_length = max_length
else:
self.max_length = genai_config["search"]["max_length"]
self._eot_token_id = genai_config["model"]["eos_token_id"]
eos = genai_config["model"]["eos_token_id"]
# eos_token_id can be a single int or a list of ints
if isinstance(eos, list):
if not eos:
raise ValueError("genai_config model.eos_token_id must not be an empty list")
self._eot_token_id = eos[0]
self.eos_token_ids = set(eos)
else:
self._eot_token_id = eos
self.eos_token_ids = {eos}
Comment thread
natke marked this conversation as resolved.
self.params = og.GeneratorParams(self.model)
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)

Expand Down Expand Up @@ -573,5 +585,135 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Generate text until a stop sequence is found or max tokens reached.

    Supports generative evaluation tasks such as MBPP and HumanEval.
    Each request is a tuple of (context_string, gen_kwargs_dict).

    :param requests: lm-eval Instance objects; ``request.args[0]`` is the prompt
        string and ``request.args[1]`` (optional) a dict of generation kwargs.
    :param disable_tqdm: suppress the progress bar when True.
    :return: one generated string per request, trimmed at the earliest stop sequence.
    """
    results = []
    for request in tqdm(requests, desc="Running generate_until", disable=disable_tqdm):
        context = request.args[0]
        gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

        # Extract stop sequences — normalise str/None/tuple/other-iterables to list[str]
        until = gen_kwargs.get("until", [])
        if isinstance(until, str):
            until = [until]
        elif until is None:
            until = []
        elif not isinstance(until, list):
            try:
                until = list(until)  # handles tuple, set, generator, etc.
            except TypeError:
                until = [until]  # non-iterable scalar fallback
        until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

        # Extract generation parameters; tolerate missing/garbage values with a 256-token default.
        max_gen_toks = gen_kwargs.get(
            "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens"))
        )
        try:
            max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
        except (TypeError, ValueError):
            max_gen_toks = 256
        max_gen_toks = max(max_gen_toks, 0)
        try:
            temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0)
        except (TypeError, ValueError):
            temperature = 0.0
        raw_do_sample = gen_kwargs.get("do_sample", None)
        if raw_do_sample is None:
            do_sample = temperature > 0
        elif isinstance(raw_do_sample, bool):
            do_sample = raw_do_sample
        elif isinstance(raw_do_sample, str):
            do_sample = raw_do_sample.lower() not in ("false", "0", "no", "")
        else:
            do_sample = bool(raw_do_sample)

        # Tokenize the prompt
        prompt_ids = self.tokenizer.encode(context).tolist()
        prompt_len = len(prompt_ids)

        # Compute total max_length: prompt + new tokens, capped by model limit
        total_max_length = min(prompt_len + max_gen_toks, self.max_length)

        # If the prompt already fills or exceeds the model limit, no generation is possible.
        if prompt_len >= self.max_length or max_gen_toks == 0:
            results.append("")
            if hasattr(request, "cache_hook") and request.cache_hook is not None:
                request.cache_hook.add_partial("generate_until", request.args, "")
            continue

        # Create fresh generator params per request to avoid state leakage
        params = og.GeneratorParams(self.model)
        search_options = {
            "max_length": total_max_length,
            "past_present_share_buffer": False,
        }
        if do_sample:
            search_options["do_sample"] = True
            search_options["temperature"] = temperature
        else:
            search_options["temperature"] = 0.0
        params.set_search_options(**search_options)

        # Run generation token by token to check for stop sequences
        generator = og.Generator(self.model, params)
        generator.append_tokens([prompt_ids])

        generated_token_ids = []
        stop_found = False
        # Character-based rolling tail wide enough to catch any stop sequence
        # across chunk boundaries, regardless of how many tokens a stop string spans.
        max_stop_len = max((len(s) for s in until), default=0)
        tail = ""

        while not generator.is_done():
            generator.generate_next_token()
            new_token = generator.get_sequence(0)[-1]

            # Check for EOS token(s)
            if new_token in self.eos_token_ids:
                break

            generated_token_ids.append(new_token)

            # Decode one token at a time only for stop-sequence tail detection.
            # The final text is produced by decoding the full ID sequence so that
            # tokenizer whitespace/punctuation normalisation is applied correctly.
            if until:
                chunk = self.tokenizer.decode([new_token])
                tail = (tail + chunk)[-(max_stop_len + len(chunk)) :]
                for stop_seq in until:
                    if stop_seq in tail:
                        stop_found = True
                        break
                if stop_found:
                    break

        # Decode full token sequence once for correct whitespace/punctuation handling.
        full_text = self.tokenizer.decode(generated_token_ids) if generated_token_ids else ""

        # Trim at the earliest stop sequence found in the final decoded text.
        generated_text = full_text
        if until:
            earliest = None
            for stop_seq in until:
                idx = full_text.find(stop_seq)
                if idx != -1 and (earliest is None or idx < earliest):
                    earliest = idx
            if earliest is not None:
                generated_text = full_text[:earliest]

        results.append(generated_text)

        # lm-eval cache hook
        if hasattr(request, "cache_hook") and request.cache_hook is not None:
            request.cache_hook.add_partial("generate_until", request.args, generated_text)

    return results

def complete(self):
    """No-op teardown hook; nothing to release for this backend."""
    return None
35 changes: 25 additions & 10 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import collections
import inspect
import logging
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from functools import partial
from functools import lru_cache, partial
from numbers import Number
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union
Expand Down Expand Up @@ -1016,6 +1017,15 @@ def _prepare_dataloader(
return FileListCommonDataLoader(dataloader, model.io_config, batch_size=file_chunk_size)


@lru_cache(maxsize=1)
def _simple_evaluate_supports_unsafe_code(simple_evaluate_fn) -> bool:
"""Check (cached) whether lm-eval's simple_evaluate accepts confirm_run_unsafe_code."""
try:
return "confirm_run_unsafe_code" in inspect.signature(simple_evaluate_fn).parameters
except (TypeError, ValueError):
return False


@Registry.register("LMEvaluator")
class LMEvaluator(OliveEvaluator):
def __init__(self, tasks: list[str], **kwargs):
Expand All @@ -1029,6 +1039,7 @@ def __init__(self, tasks: list[str], **kwargs):
self.ep = kwargs.get("execution_provider")
self.ep_options = kwargs.get("provider_options")
self.device = kwargs.get("device")
self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False)

def evaluate(
self,
Expand Down Expand Up @@ -1100,15 +1111,19 @@ def evaluate(
if self.tasks:
lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length)

results = simple_evaluate(
model=lmmodel,
tasks=self.tasks,
task_manager=TaskManager(),
log_samples=False,
batch_size=self.batch_size,
device=device,
limit=self.limit,
)
simple_evaluate_kwargs = {
"model": lmmodel,
"tasks": self.tasks,
"task_manager": TaskManager(),
"log_samples": False,
"batch_size": self.batch_size,
"device": device,
"limit": self.limit,
}
# Only pass confirm_run_unsafe_code when the installed lm-eval version supports it.
if _simple_evaluate_supports_unsafe_code(simple_evaluate):
simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code
results = simple_evaluate(**simple_evaluate_kwargs)

for task_name in sorted(results["results"].keys()):
metric_items = sorted(results["results"][task_name].items())
Expand Down
25 changes: 25 additions & 0 deletions test/evaluator/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Inject a minimal onnxruntime_genai stub for generate_until unit tests.

Ensures tests can run in environments where the real package is not installed.
The tests mock all ORT GenAI objects anyway, so the stub only needs to provide
importable names.
"""

import sys
import types
from unittest.mock import MagicMock


def _ensure_ort_genai_stub():
    """Register a placeholder onnxruntime_genai module when the real one is absent.

    The generate_until tests mock every ORT GenAI object, so the stub only has
    to make the names importable.
    """
    if "onnxruntime_genai" in sys.modules:
        return
    fake_module = types.ModuleType("onnxruntime_genai")
    fake_module.Generator = MagicMock
    fake_module.GeneratorParams = MagicMock
    sys.modules["onnxruntime_genai"] = fake_module


_ensure_ort_genai_stub()
Loading
Loading