Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions olive/cli/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser):
help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.",
)

lmeval_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=False,
help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).",
)

add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
Expand Down Expand Up @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict:
("evaluators", "evaluator", "model_class"),
None if self.args.backend == "auto" else self.args.backend,
),
(
("evaluators", "evaluator", "confirm_run_unsafe_code"),
self.args.confirm_run_unsafe_code or None,
Comment thread
natke marked this conversation as resolved.
Outdated
),
]

for keys, value in to_replace:
Expand Down
128 changes: 126 additions & 2 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[fl
raise NotImplementedError("Yet to be implemented!")

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Free-form generation is not supported by this (logits-only) backend.

    Raises:
        NotImplementedError: always; directs users to the 'ortgenai' model
            class, which implements generation for tasks like MBPP/HumanEval.
    """
    # Diff residue removed: the span previously contained both the old
    # "Yet to be implemented!" raise and this replacement, leaving an
    # unreachable duplicate statement.
    raise NotImplementedError(
        "generate_until is not supported by this model backend. "
        "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval."
    )


@register_model("ort")
Expand Down Expand Up @@ -509,7 +512,16 @@ def __init__(
self.max_length = max_length
else:
self.max_length = genai_config["search"]["max_length"]
self._eot_token_id = genai_config["model"]["eos_token_id"]
eos = genai_config["model"]["eos_token_id"]
# eos_token_id can be a single int or a list of ints
if isinstance(eos, list):
if not eos:
raise ValueError("genai_config model.eos_token_id must not be an empty list")
self._eot_token_id = eos[0]
self.eos_token_ids = set(eos)
else:
self._eot_token_id = eos
self.eos_token_ids = {eos}
Comment thread
natke marked this conversation as resolved.
self.params = og.GeneratorParams(self.model)
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)

Expand Down Expand Up @@ -573,5 +585,117 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Generate text until a stop sequence is found or max tokens reached.

    Supports generative evaluation tasks such as MBPP and HumanEval.

    Args:
        requests: lm-eval request objects; each request's ``args`` is
            ``(context_string, gen_kwargs_dict)``.
        disable_tqdm: unused here; kept for interface parity with lm-eval.

    Returns:
        One generated string per request, truncated just before the earliest
        stop sequence when one is found.
    """
    results = []
    for request in requests:
        context = request.args[0]
        gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

        # Normalize stop sequences to a list of non-empty strings.
        until = gen_kwargs.get("until", [])
        if isinstance(until, str):
            until = [until]
        elif until is None:
            until = []
        elif not isinstance(until, list):
            until = [until]
        until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

        # Extract generation parameters, tolerating malformed values.
        max_gen_toks = gen_kwargs.get(
            "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens"))
        )
        try:
            max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
        except (TypeError, ValueError):
            max_gen_toks = 256
        max_gen_toks = max(max_gen_toks, 0)
        try:
            temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0)
        except (TypeError, ValueError):
            temperature = 0.0
        do_sample = bool(gen_kwargs.get("do_sample", temperature > 0))

        # Tokenize the prompt.
        prompt_ids = self.tokenizer.encode(context).tolist()
        prompt_len = len(prompt_ids)

        # Compute total max_length: prompt + new tokens, capped by model limit.
        total_max_length = min(prompt_len + max_gen_toks, self.max_length)

        # If the prompt already fills or exceeds the model limit, no generation is possible.
        if prompt_len >= self.max_length or max_gen_toks == 0:
            results.append("")
            if hasattr(request, "cache_hook") and request.cache_hook is not None:
                request.cache_hook.add_partial("generate_until", request.args, "")
            continue

        # Create fresh generator params per request to avoid state leakage.
        params = og.GeneratorParams(self.model)
        search_options = {
            "max_length": total_max_length,
            "past_present_share_buffer": False,
            # BUGFIX: `do_sample` was previously computed but never passed to
            # the search options, so temperature>0 requests still decoded
            # greedily — onnxruntime-genai only samples when do_sample=True.
            "do_sample": do_sample,
        }
        if do_sample:
            search_options["temperature"] = temperature
        else:
            search_options["temperature"] = 0.0
        params.set_search_options(**search_options)

        # Run generation token by token to check for stop sequences.
        generator = og.Generator(self.model, params)
        generator.append_tokens([prompt_ids])

        generated_chunks = []
        generated_len = 0  # running total character count, avoids O(n²) join for offset
        stop_idx = None
        # Character-based rolling tail wide enough to catch any stop sequence
        # across chunk boundaries, regardless of how many tokens a stop string spans.
        max_stop_len = max((len(s) for s in until), default=0)
        tail = ""

        while not generator.is_done():
            generator.generate_next_token()
            new_token = generator.get_sequence(0)[-1]

            # Check for EOS token(s); eos_token_ids covers the list-valued
            # genai_config case as well as a single id.
            if new_token in self.eos_token_ids:
                break

            chunk = self.tokenizer.decode([new_token])
            generated_chunks.append(chunk)
            generated_len += len(chunk)

            # Maintain a character-based tail of exactly max_stop_len + len(chunk) chars
            # so stop sequences that span chunk boundaries are never missed.
            if until:
                tail = (tail + chunk)[-(max_stop_len + len(chunk)) :]
                tail_offset = generated_len - len(tail)
                earliest = None
                for stop_seq in until:
                    idx = tail.find(stop_seq)
                    if idx != -1:
                        abs_idx = tail_offset + idx
                        if earliest is None or abs_idx < earliest:
                            earliest = abs_idx
                if earliest is not None:
                    stop_idx = earliest
                    break

        # Join once; truncate just before the earliest stop sequence if found.
        generated_text = "".join(generated_chunks)
        if stop_idx is not None:
            generated_text = generated_text[:stop_idx]

        results.append(generated_text)

        # lm-eval cache hook
        if hasattr(request, "cache_hook") and request.cache_hook is not None:
            request.cache_hook.add_partial("generate_until", request.args, generated_text)

    return results

def complete(self):
    # No-op finalizer; presumably required by the harness/model interface
    # this class plugs into — TODO confirm against the caller.
    pass
28 changes: 19 additions & 9 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ def __init__(self, tasks: list[str], **kwargs):
self.ep = kwargs.get("execution_provider")
self.ep_options = kwargs.get("provider_options")
self.device = kwargs.get("device")
self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False)

def evaluate(
self,
Expand Down Expand Up @@ -1100,15 +1101,24 @@ def evaluate(
if self.tasks:
lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length)

results = simple_evaluate(
model=lmmodel,
tasks=self.tasks,
task_manager=TaskManager(),
log_samples=False,
batch_size=self.batch_size,
device=device,
limit=self.limit,
)
simple_evaluate_kwargs = {
"model": lmmodel,
"tasks": self.tasks,
"task_manager": TaskManager(),
"log_samples": False,
"batch_size": self.batch_size,
"device": device,
"limit": self.limit,
"confirm_run_unsafe_code": self.confirm_run_unsafe_code,
}
try:
results = simple_evaluate(**simple_evaluate_kwargs)
except TypeError as e:
if "confirm_run_unsafe_code" not in str(e):
raise
# Older lm-eval versions don't support confirm_run_unsafe_code; retry without it
simple_evaluate_kwargs.pop("confirm_run_unsafe_code")
results = simple_evaluate(**simple_evaluate_kwargs)
Comment thread
natke marked this conversation as resolved.
Outdated

for task_name in sorted(results["results"].keys()):
metric_items = sorted(results["results"][task_name].items())
Expand Down
Loading
Loading