11 changes: 11 additions & 0 deletions olive/cli/benchmark.py
@@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser):
help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.",
)

lmeval_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=False,
help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).",
)

add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
@@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict:
("evaluators", "evaluator", "model_class"),
None if self.args.backend == "auto" else self.args.backend,
),
(
("evaluators", "evaluator", "confirm_run_unsafe_code"),
self.args.confirm_run_unsafe_code or None,
),
]

for keys, value in to_replace:
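For context, the `to_replace` list above patches key paths into the generated run config. A minimal sketch of the effect, assuming a standard nested-dict run config — the `set_nested` helper and the surrounding dict shape are illustrative, not Olive API; only the key path `("evaluators", "evaluator", "confirm_run_unsafe_code")` comes from the diff:

```python
run_config = {
    "evaluators": {
        "evaluator": {
            "type": "OliveEvaluator",
            "model_class": "ortgenai",
        }
    }
}

def set_nested(config: dict, keys: tuple, value) -> None:
    """Walk the key path and set the leaf value, mirroring the to_replace loop."""
    if value is None:
        # Assumed semantics: the diff passes `... or None` so that the default
        # (flag not set) leaves the config untouched.
        return
    for key in keys[:-1]:
        config = config[key]
    config[keys[-1]] = value

set_nested(run_config, ("evaluators", "evaluator", "confirm_run_unsafe_code"), True)
# run_config["evaluators"]["evaluator"]["confirm_run_unsafe_code"] is now True
```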
142 changes: 140 additions & 2 deletions olive/evaluator/lmeval_ort.py
@@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[float]:
raise NotImplementedError("Yet to be implemented!")

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
raise NotImplementedError("Yet to be implemented!")
raise NotImplementedError(
"generate_until is not supported by this model backend. "
"Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval."
)


@register_model("ort")
@@ -509,7 +512,16 @@ def __init__(
self.max_length = max_length
else:
self.max_length = genai_config["search"]["max_length"]
self._eot_token_id = genai_config["model"]["eos_token_id"]
eos = genai_config["model"]["eos_token_id"]
# eos_token_id can be a single int or a list of ints
if isinstance(eos, list):
if not eos:
raise ValueError("genai_config model.eos_token_id must not be an empty list")
self._eot_token_id = eos[0]
self.eos_token_ids = set(eos)
else:
self._eot_token_id = eos
self.eos_token_ids = {eos}
self.params = og.GeneratorParams(self.model)
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)
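As a side note, the eos normalisation above can be exercised in isolation. A sketch with illustrative values — the `normalize_eos` function name is hypothetical; only the two `eos_token_id` shapes come from the diff:

```python
def normalize_eos(eos):
    """Return (primary eot id, full eos id set) for int or list-of-int input."""
    if isinstance(eos, list):
        if not eos:
            raise ValueError("genai_config model.eos_token_id must not be an empty list")
        return eos[0], set(eos)
    return eos, {eos}

assert normalize_eos(2) == (2, {2})
assert normalize_eos([2, 32000, 32007]) == (2, {2, 32000, 32007})
```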

@@ -573,5 +585,131 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
"""Generate text until a stop sequence is found or max tokens reached.

Supports generative evaluation tasks such as MBPP and HumanEval.
Each request is a tuple of (context_string, gen_kwargs_dict).
"""
results = []
for request in requests:
context = request.args[0]
gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

# Extract stop sequences
until = gen_kwargs.get("until", [])
if isinstance(until, str):
until = [until]
elif until is None:
until = []
elif not isinstance(until, list):
until = [until]
until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

# Extract generation parameters
max_gen_toks = gen_kwargs.get(
"max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens"))
)
try:
max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
except (TypeError, ValueError):
max_gen_toks = 256
max_gen_toks = max(max_gen_toks, 0)
try:
temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0)
except (TypeError, ValueError):
temperature = 0.0
raw_do_sample = gen_kwargs.get("do_sample", None)
if raw_do_sample is None:
do_sample = temperature > 0
elif isinstance(raw_do_sample, bool):
do_sample = raw_do_sample
elif isinstance(raw_do_sample, str):
do_sample = raw_do_sample.lower() not in ("false", "0", "no", "")
else:
do_sample = bool(raw_do_sample)

# Tokenize the prompt
prompt_ids = self.tokenizer.encode(context).tolist()
prompt_len = len(prompt_ids)

# Compute total max_length: prompt + new tokens, capped by model limit
total_max_length = min(prompt_len + max_gen_toks, self.max_length)

# If the prompt already fills or exceeds the model limit, no generation is possible.
if prompt_len >= self.max_length or max_gen_toks == 0:
results.append("")
if hasattr(request, "cache_hook") and request.cache_hook is not None:
request.cache_hook.add_partial("generate_until", request.args, "")
continue

# Create fresh generator params per request to avoid state leakage
params = og.GeneratorParams(self.model)
search_options = {
"max_length": total_max_length,
"past_present_share_buffer": False,
}
if do_sample:
search_options["temperature"] = temperature
else:
search_options["temperature"] = 0.0
params.set_search_options(**search_options)

# Run generation token by token to check for stop sequences
generator = og.Generator(self.model, params)
generator.append_tokens([prompt_ids])

generated_token_ids = []
stop_found = False
# Character-based rolling tail wide enough to catch any stop sequence
# across chunk boundaries, regardless of how many tokens a stop string spans.
max_stop_len = max((len(s) for s in until), default=0)
tail = ""

while not generator.is_done():
generator.generate_next_token()
new_token = generator.get_sequence(0)[-1]

# Check for EOS token(s)
if new_token in self.eos_token_ids:
break

generated_token_ids.append(new_token)

# Decode one token at a time only for stop-sequence tail detection.
# The final text is produced by decoding the full ID sequence so that
# tokenizer whitespace/punctuation normalisation is applied correctly.
if until:
chunk = self.tokenizer.decode([new_token])
tail = (tail + chunk)[-(max_stop_len + len(chunk)) :]
for stop_seq in until:
if stop_seq in tail:
stop_found = True
break
if stop_found:
break

# Decode full token sequence once for correct whitespace/punctuation handling.
full_text = self.tokenizer.decode(generated_token_ids) if generated_token_ids else ""

# Trim at the earliest stop sequence found in the final decoded text.
generated_text = full_text
if until:
earliest = None
for stop_seq in until:
idx = full_text.find(stop_seq)
if idx != -1 and (earliest is None or idx < earliest):
earliest = idx
if earliest is not None:
generated_text = full_text[:earliest]

results.append(generated_text)

# lm-eval cache hook
if hasattr(request, "cache_hook") and request.cache_hook is not None:
request.cache_hook.add_partial("generate_until", request.args, generated_text)

return results

def complete(self):
pass
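The rolling-tail check in `generate_until` is the subtle part: decoding one token at a time and keeping only a bounded character tail is enough to catch a stop string that spans several tokens. A standalone sketch under those assumptions — `find_stop` and the sample chunks are illustrative, not part of the diff:

```python
def find_stop(chunks, until) -> bool:
    """Detect a stop string across per-token decoded chunks, as in the diff."""
    max_stop_len = max((len(s) for s in until), default=0)
    tail = ""
    for chunk in chunks:  # chunk = tokenizer.decode([new_token]) per step
        # Keep just enough characters to span any stop sequence plus this chunk.
        tail = (tail + chunk)[-(max_stop_len + len(chunk)):]
        if any(stop in tail for stop in until):
            return True
    return False

# "```" split as "``" + "`" across two decoded chunks is still caught:
assert find_stop(["def f():", " pass\n``", "`"], ["```"]) is True
```

The final trim then happens once on the fully decoded text, which is why the loop only needs a boolean "stop found" signal rather than an exact offset.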
24 changes: 15 additions & 9 deletions olive/evaluator/olive_evaluator.py
@@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import collections
import inspect
import logging
import time
from abc import ABC, abstractmethod
@@ -1029,6 +1030,7 @@ def __init__(self, tasks: list[str], **kwargs):
self.ep = kwargs.get("execution_provider")
self.ep_options = kwargs.get("provider_options")
self.device = kwargs.get("device")
self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False)

def evaluate(
self,
@@ -1100,15 +1102,19 @@ def evaluate(
if self.tasks:
lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length)

results = simple_evaluate(
model=lmmodel,
tasks=self.tasks,
task_manager=TaskManager(),
log_samples=False,
batch_size=self.batch_size,
device=device,
limit=self.limit,
)
simple_evaluate_kwargs = {
"model": lmmodel,
"tasks": self.tasks,
"task_manager": TaskManager(),
"log_samples": False,
"batch_size": self.batch_size,
"device": device,
"limit": self.limit,
}
# Only pass confirm_run_unsafe_code when the installed lm-eval version supports it.
if "confirm_run_unsafe_code" in inspect.signature(simple_evaluate).parameters:
simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code
results = simple_evaluate(**simple_evaluate_kwargs)

for task_name in sorted(results["results"].keys()):
metric_items = sorted(results["results"][task_name].items())
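The `inspect.signature` gate above is a general pattern for staying compatible with older lm-eval releases: pass an optional kwarg only when the installed callable accepts it. A hedged sketch — the helper and the `old_api` stand-in are hypothetical:

```python
import inspect

def call_with_optional_kwarg(fn, kwargs: dict, name: str, value):
    """Pass name=value only if fn accepts a parameter with that name."""
    if name in inspect.signature(fn).parameters:
        kwargs[name] = value
    return fn(**kwargs)

def old_api(model, tasks):  # stands in for an older lm-eval simple_evaluate
    return {"model": model, "tasks": tasks}

result = call_with_optional_kwarg(
    old_api, {"model": "m", "tasks": ["mbpp"]}, "confirm_run_unsafe_code", True
)
# old_api lacks the parameter, so the kwarg is silently dropped
```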