Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions olive/cli/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser):
help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.",
)

lmeval_group.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
default=False,
help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).",
)

add_logging_options(sub_parser)
add_save_config_file_options(sub_parser)
add_shared_cache_options(sub_parser)
Expand Down Expand Up @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict:
("evaluators", "evaluator", "model_class"),
None if self.args.backend == "auto" else self.args.backend,
),
(
("evaluators", "evaluator", "confirm_run_unsafe_code"),
self.args.confirm_run_unsafe_code or None,
Comment thread
natke marked this conversation as resolved.
Outdated
),
]

for keys, value in to_replace:
Expand Down
102 changes: 100 additions & 2 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Raise: free-text generation is not implemented for this ORT backend.

    The diff overlay in the paste left both the old one-line raise and its
    multi-line replacement in the span; only the replacement survives here.
    Generative tasks (e.g. MBPP, HumanEval) require the 'ortgenai' model class.

    Raises:
        NotImplementedError: always, with a pointer to the supported backend.
    """
    raise NotImplementedError(
        "generate_until is not supported by this model backend. "
        "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval."
    )


@register_model("ort")
Expand Down Expand Up @@ -509,7 +512,14 @@ def __init__(
self.max_length = max_length
else:
self.max_length = genai_config["search"]["max_length"]
self._eot_token_id = genai_config["model"]["eos_token_id"]
eos = genai_config["model"]["eos_token_id"]
# eos_token_id can be a single int or a list of ints
if isinstance(eos, list):
self._eot_token_id = eos[0]
self._eos_token_ids = set(eos)
else:
self._eot_token_id = eos
self._eos_token_ids = {eos}
self.params = og.GeneratorParams(self.model)
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)

Expand Down Expand Up @@ -573,5 +583,93 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Generate text for each request until a stop sequence or the token budget is hit.

    Supports generative evaluation tasks such as MBPP and HumanEval.

    Args:
        requests: lm-eval request objects; ``request.args`` is
            ``(context_string, gen_kwargs_dict)``.
        disable_tqdm: accepted for interface compatibility; not used here.

    Returns:
        One generated string per request, truncated at the earliest stop
        sequence when one appears in the continuation.
    """
    results = []
    for request in requests:
        context = request.args[0]
        gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

        # Normalize stop sequences to a list of non-empty strings
        # (task configs may pass a single string, None, or a list).
        until = gen_kwargs.get("until", [])
        if until is None:
            until = []
        elif not isinstance(until, list):
            until = [until]
        until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

        # Token budget: accept the aliases commonly used by lm-eval task configs,
        # falling back to 256 on anything missing or non-numeric.
        max_gen_toks = gen_kwargs.get("max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")))
        try:
            max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
        except (TypeError, ValueError):
            max_gen_toks = 256
        max_gen_toks = max(max_gen_toks, 0)
        temperature = gen_kwargs.get("temperature", 0.0)
        do_sample = gen_kwargs.get("do_sample", temperature > 0)

        # Tokenize the prompt.
        prompt_ids = self.tokenizer.encode(context).tolist()
        prompt_len = len(prompt_ids)

        # Total sequence budget: prompt + new tokens, capped by the model limit.
        # NOTE(review): if prompt_len already exceeds self.max_length the
        # generator gets a budget smaller than the prompt — confirm upstream
        # truncation handles over-long prompts.
        total_max_length = min(prompt_len + max_gen_toks, self.max_length)

        # Fresh generator params per request so search state cannot leak
        # between requests.
        params = og.GeneratorParams(self.model)
        search_options = {
            "max_length": total_max_length,
            "past_present_share_buffer": False,
            "batch_size": 1,
            # FIX: `do_sample` was previously computed from gen_kwargs but never
            # passed to the search options, so temperature > 0 silently decoded
            # greedily. Propagate it alongside the temperature.
            "do_sample": bool(do_sample),
            "temperature": temperature if do_sample else 0.0,
        }
        params.set_search_options(**search_options)

        # Run generation token by token so stop sequences can be checked.
        generator = og.Generator(self.model, params)
        generator.append_tokens([prompt_ids])

        generated_ids = []
        generated_text = ""

        while not generator.is_done():
            generator.generate_next_token()
            new_token = generator.get_sequence(0)[-1]

            # Stop on any EOS id (genai config may declare one id or a list).
            if new_token in self._eos_token_ids:
                break

            generated_ids.append(new_token)
            # FIX: decode the accumulated ids instead of appending per-token
            # decodes. Decoding one token at a time corrupts output for
            # tokenizers whose tokens are not self-contained strings (e.g.
            # SentencePiece byte-fallback / multi-byte BPE pieces).
            generated_text = self.tokenizer.decode(generated_ids)

            # Truncate at the earliest occurrence of any stop sequence.
            earliest_stop_idx = None
            for stop_seq in until:
                stop_idx = generated_text.find(stop_seq)
                if stop_idx != -1 and (earliest_stop_idx is None or stop_idx < earliest_stop_idx):
                    earliest_stop_idx = stop_idx

            if earliest_stop_idx is not None:
                generated_text = generated_text[:earliest_stop_idx]
                break

        results.append(generated_text)

        # lm-eval cache hook: lets the harness reuse results across runs.
        if hasattr(request, "cache_hook") and request.cache_hook is not None:
            request.cache_hook.add_partial("generate_until", request.args, generated_text)

    return results

def complete(self):
    # NOTE(review): appears to be a lifecycle/cleanup hook on this model
    # wrapper; this backend holds nothing that needs explicit teardown, so it
    # is intentionally a no-op. Confirm against the caller's expected contract.
    pass
2 changes: 2 additions & 0 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ def __init__(self, tasks: list[str], **kwargs):
self.ep = kwargs.get("execution_provider")
self.ep_options = kwargs.get("provider_options")
self.device = kwargs.get("device")
self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False)

def evaluate(
self,
Expand Down Expand Up @@ -1108,6 +1109,7 @@ def evaluate(
batch_size=self.batch_size,
device=device,
limit=self.limit,
confirm_run_unsafe_code=self.confirm_run_unsafe_code,
)
Comment thread
natke marked this conversation as resolved.
Outdated

for task_name in sorted(results["results"].keys()):
Expand Down
Loading
Loading