diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index adad95773..a3b3f25e8 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser): help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.", ) + lmeval_group.add_argument( + "--confirm_run_unsafe_code", + action="store_true", + default=None, + help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).", + ) + add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict: ("evaluators", "evaluator", "model_class"), None if self.args.backend == "auto" else self.args.backend, ), + ( + ("evaluators", "evaluator", "confirm_run_unsafe_code"), + True if self.args.confirm_run_unsafe_code else None, + ), ] for keys, value in to_replace: diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index fd69b066e..5f9ac921a 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[fl raise NotImplementedError("Yet to be implemented!") def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: - raise NotImplementedError("Yet to be implemented!") + raise NotImplementedError( + "generate_until is not supported by this model backend. " + "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval." + ) @register_model("ort") @@ -509,7 +512,16 @@ def __init__( self.max_length = max_length else: self.max_length = genai_config["search"]["max_length"] - self._eot_token_id = genai_config["model"]["eos_token_id"] + eos = genai_config["model"]["eos_token_id"] + # eos_token_id can be a single int or a list of ints + if isinstance(eos, list): + if not eos: + raise ValueError("genai_config model.eos_token_id must not be an empty list") + self._eot_token_id = eos[0] + self.eos_token_ids = set(eos) + else: + self._eot_token_id = eos + self.eos_token_ids = {eos} self.params = og.GeneratorParams(self.model) self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False) @@ -573,5 +585,135 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor # seq dimension so the continuation slice still lands on the correct positions. return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab] + def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: + """Generate text until a stop sequence is found or max tokens reached. + + Supports generative evaluation tasks such as MBPP and HumanEval. + Each request is a tuple of (context_string, gen_kwargs_dict). + """ + results = [] + for request in tqdm(requests, desc="Running generate_until", disable=disable_tqdm): + context = request.args[0] + gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {} + + # Extract stop sequences — normalise str/None/tuple/other-iterables to list[str] + until = gen_kwargs.get("until", []) + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + try: + until = list(until) # handles tuple, set, generator, etc. 
+ except TypeError: + until = [until] # non-iterable scalar fallback + until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq] + + # Extract generation parameters + max_gen_toks = gen_kwargs.get( + "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")) + ) + try: + max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256 + except (TypeError, ValueError): + max_gen_toks = 256 + max_gen_toks = max(max_gen_toks, 0) + try: + temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0) + except (TypeError, ValueError): + temperature = 0.0 + raw_do_sample = gen_kwargs.get("do_sample", None) + if raw_do_sample is None: + do_sample = temperature > 0 + elif isinstance(raw_do_sample, bool): + do_sample = raw_do_sample + elif isinstance(raw_do_sample, str): + do_sample = raw_do_sample.lower() not in ("false", "0", "no", "") + else: + do_sample = bool(raw_do_sample) + + # Tokenize the prompt + prompt_ids = self.tokenizer.encode(context).tolist() + prompt_len = len(prompt_ids) + + # Compute total max_length: prompt + new tokens, capped by model limit + total_max_length = min(prompt_len + max_gen_toks, self.max_length) + + # If the prompt already fills or exceeds the model limit, no generation is possible. + if prompt_len >= self.max_length or max_gen_toks == 0: + results.append("") + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, "") + continue + + # Create fresh generator params per request to avoid state leakage + params = og.GeneratorParams(self.model) + search_options = { + "max_length": total_max_length, + "past_present_share_buffer": False, + } + if do_sample: + search_options["do_sample"] = True + search_options["temperature"] = temperature + else: + search_options["temperature"] = 0.0 + params.set_search_options(**search_options) + + # Run generation token by token to check for stop sequences + generator = og.Generator(self.model, params) + generator.append_tokens([prompt_ids]) + + generated_token_ids = [] + stop_found = False + # Character-based rolling tail wide enough to catch any stop sequence + # across chunk boundaries, regardless of how many tokens a stop string spans. + max_stop_len = max((len(s) for s in until), default=0) + tail = "" + + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_sequence(0)[-1] + + # Check for EOS token(s) + if new_token in self.eos_token_ids: + break + + generated_token_ids.append(new_token) + + # Decode one token at a time only for stop-sequence tail detection. + # The final text is produced by decoding the full ID sequence so that + # tokenizer whitespace/punctuation normalisation is applied correctly. + if until: + chunk = self.tokenizer.decode([new_token]) + tail = (tail + chunk)[-(max_stop_len + len(chunk)) :] + for stop_seq in until: + if stop_seq in tail: + stop_found = True + break + if stop_found: + break + + # Decode full token sequence once for correct whitespace/punctuation handling. + full_text = self.tokenizer.decode(generated_token_ids) if generated_token_ids else "" + + # Trim at the earliest stop sequence found in the final decoded text. 
+ generated_text = full_text + if until: + earliest = None + for stop_seq in until: + idx = full_text.find(stop_seq) + if idx != -1 and (earliest is None or idx < earliest): + earliest = idx + if earliest is not None: + generated_text = full_text[:earliest] + + results.append(generated_text) + + # lm-eval cache hook + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, generated_text) + + return results + def complete(self): pass diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 0814850a1..d29629920 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -3,11 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import collections +import inspect import logging import time from abc import ABC, abstractmethod from copy import deepcopy -from functools import partial +from functools import lru_cache, partial from numbers import Number from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union @@ -1016,6 +1017,15 @@ def _prepare_dataloader( return FileListCommonDataLoader(dataloader, model.io_config, batch_size=file_chunk_size) +@lru_cache(maxsize=1) +def _simple_evaluate_supports_unsafe_code(simple_evaluate_fn) -> bool: + """Check (cached) whether lm-eval's simple_evaluate accepts confirm_run_unsafe_code.""" + try: + return "confirm_run_unsafe_code" in inspect.signature(simple_evaluate_fn).parameters + except (TypeError, ValueError): + return False + + @Registry.register("LMEvaluator") class LMEvaluator(OliveEvaluator): def __init__(self, tasks: list[str], **kwargs): @@ -1029,6 +1039,7 @@ def __init__(self, tasks: list[str], **kwargs): self.ep = kwargs.get("execution_provider") self.ep_options = kwargs.get("provider_options") self.device = kwargs.get("device") + self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False) def evaluate( self, @@ -1100,15 +1111,19 @@ def evaluate( if self.tasks: lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length) - results = simple_evaluate( - model=lmmodel, - tasks=self.tasks, - task_manager=TaskManager(), - log_samples=False, - batch_size=self.batch_size, - device=device, - limit=self.limit, - ) + simple_evaluate_kwargs = { + "model": lmmodel, + "tasks": self.tasks, + "task_manager": TaskManager(), + "log_samples": False, + "batch_size": self.batch_size, + "device": device, + "limit": self.limit, + } + # Only pass confirm_run_unsafe_code when the installed lm-eval version supports it. + if _simple_evaluate_supports_unsafe_code(simple_evaluate): + simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code + results = simple_evaluate(**simple_evaluate_kwargs) for task_name in sorted(results["results"].keys()): metric_items = sorted(results["results"][task_name].items()) diff --git a/test/evaluator/conftest.py b/test/evaluator/conftest.py new file mode 100644 index 000000000..8617534db --- /dev/null +++ b/test/evaluator/conftest.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Inject a minimal onnxruntime_genai stub for generate_until unit tests. 
+ +Ensures tests can run in environments where the real package is not installed. +The tests mock all ORT GenAI objects anyway, so the stub only needs to provide +importable names. +""" + +import sys +import types +from unittest.mock import MagicMock + + +def _ensure_ort_genai_stub(): + if "onnxruntime_genai" not in sys.modules: + stub = types.ModuleType("onnxruntime_genai") + stub.Generator = MagicMock + stub.GeneratorParams = MagicMock + sys.modules["onnxruntime_genai"] = stub + + +_ensure_ort_genai_stub() diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index e295d069a..5bc819a7b 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -496,9 +496,15 @@ class TestLMEvaluatorModelClass: def test_lm_evaluator_dispatches_to_requested_backend( self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock, model_class ): + import inspect + from olive.evaluator.olive_evaluator import LMEvaluator from olive.model.handler.onnx import ONNXModelHandler + def _fake_evaluate(model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) simple_evaluate_mock.return_value = {"results": {}} get_model_mock.return_value = MagicMock(return_value=MagicMock()) @@ -510,3 +516,601 @@ def test_lm_evaluator_dispatches_to_requested_backend( evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) get_model_mock.assert_called_once_with(model_class) + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_passes_confirm_run_unsafe_code( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. 
+ def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code=True was passed to simple_evaluate + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is True + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_confirm_run_unsafe_code_defaults_false( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. + def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator(tasks=["arc_easy"], model_class="ort", batch_size=1, max_length=128) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code defaults to False + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is False + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_skips_confirm_run_unsafe_code_for_older_lm_eval( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + """When lm-eval lacks confirm_run_unsafe_code, the kwarg must not be passed.""" + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Mock a signature WITHOUT confirm_run_unsafe_code (simulates older lm-eval). 
+ def _fake_evaluate_old( + model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate_old) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + call_kwargs = simple_evaluate_mock.call_args[1] + assert "confirm_run_unsafe_code" not in call_kwargs + + +@pytest.mark.skipif( + importlib.util.find_spec("lm_eval") is None, + reason="lm_eval not installed", +) +class TestLMEvalORTGenAIGenerateUntil: + """Unit tests for LMEvalORTGenAIEvaluator.generate_until.""" + + def _make_mock_request(self, context, gen_kwargs): + """Create a mock lm-eval Request object.""" + req = MagicMock() + req.args = (context, gen_kwargs) + req.cache_hook = MagicMock() + return req + + def _mock_encode(self, ids): + """Return a mock that behaves like tokenizer.encode() output (has .tolist()).""" + import numpy as np + + return np.array(ids) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_eos(self, mock_params_cls, mock_gen_cls): + """Test that generation stops when EOS token is produced.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100, 200]) # 3-token prompt + evaluator.tokenizer.decode.return_value = "hello" + + # Generator produces one token then EOS + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), # first token + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("def foo():", {"until": ["\n"], "max_gen_toks": 100}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + # After EOS on second token, only first token was appended → decode called once + assert results[0] == "hello" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_cls): + """Test that generation stops and trims at stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + + evaluator.tokenizer.decode.side_effect = ["he", "l", "lo\n world", "hello\n world"] + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), + MagicMock(__getitem__=lambda s, k: 51), + MagicMock(__getitem__=lambda s, k: 52), + ] + mock_gen_cls.return_value = 
mock_generator + + request = self._make_mock_request("prompt", {"until": ["\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" # trimmed at \n + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_respects_max_length(self, mock_params_cls, mock_gen_cls): + """Test that total_max_length = min(prompt_len + max_gen_toks, max_length).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 50 # Small model limit + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode(list(range(40))) # 40-token prompt + evaluator.tokenizer.decode.return_value = "x" + + # Generator immediately done (max_length reached) + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("long prompt", {"until": ["\n"], "max_gen_toks": 100}) + + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + # Verify search options set max_length = min(40+100, 50) = 50 + set_search_call = mock_params_cls.return_value.set_search_options + call_kwargs = set_search_call.call_args[1] + assert call_kwargs["max_length"] == 50 + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_multiple_eos_tokens(self, mock_params_cls, mock_gen_cls): + """Test that any token in eos_token_ids triggers stop.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2, 151645, 151643} # Multiple EOS like Qwen + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "result" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + # Second EOS token in the set triggers stop + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), + MagicMock(__getitem__=lambda s, k: 151643), # alternate EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "result" + + def test_generate_until_until_string_converted_to_list(self): + """Test that a string 'until' value is converted to a list.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + evaluator.tokenizer.decode.return_value = "x\n" + + with patch("onnxruntime_genai.GeneratorParams"), patch("onnxruntime_genai.Generator") as mock_gen_cls: + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + # Pass until as string, not list + request = 
self._make_mock_request("p", {"until": "\n", "max_gen_toks": 10}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + # Should still find the stop sequence (string was converted to list) + assert "\n" not in results[0] + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_uses_earliest_stop_match(self, mock_params_cls, mock_gen_cls): + """Test that stop trimming uses earliest occurrence across all stop sequences.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\nworld" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": ["", "\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" + + @pytest.mark.parametrize( + ("gen_kwargs", "expected_max_length"), + [ + (None, 261), # default 256 when gen_kwargs is not a dict + ({"max_gen_toks": "7"}, 12), # parse numeric string + ({"max_new_tokens": "bad"}, 261), # invalid value falls back to default + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_parses_max_tokens_robustly( + self, mock_params_cls, mock_gen_cls, gen_kwargs, expected_max_length + ): + """Test robust parsing and clamping of max token kwargs.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert call_kwargs["max_length"] == expected_max_length + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_does_not_pass_batch_size_to_search_options(self, mock_params_cls, mock_gen_cls): + """batch_size is not a valid set_search_options kwarg for ORT GenAI — must never be passed.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + 
call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert "batch_size" not in call_kwargs, ( + f"batch_size must not be passed to set_search_options, got: {call_kwargs}" + ) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_returns_empty_when_max_gen_toks_zero(self, mock_params_cls, mock_gen_cls): + """Test that clamping a negative max_tokens to zero returns an empty completion immediately.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + request = self._make_mock_request("prompt", {"max_tokens": -8}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == [""] + mock_gen_cls.assert_not_called() # generator should never be created + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cls): + """Test generation decodes only new tokens while preserving output.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello" # returned for full-sequence decode + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 11), + MagicMock(__getitem__=lambda s, k: 12), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": []}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + # With no stop sequences, tokens are decoded once as a full sequence (not per-token). 
+ decode_inputs = [call.args[0] for call in evaluator.tokenizer.decode.call_args_list] + assert decode_inputs == [[11, 12]] + + @pytest.mark.parametrize( + ("temperature_val", "expect_do_sample"), + [ + ("0.7", True), # string float should be coerced + (None, False), # None should fall back to 0.0 + (0.0, False), # zero means greedy + (0.5, True), # normal float + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_temperature_coercion( + self, mock_params_cls, mock_gen_cls, temperature_val, expect_do_sample + ): + """Test that temperature is safely coerced from string/None without errors.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + gen_kwargs = {"until": [], "max_gen_toks": 10} + if temperature_val is not None: + gen_kwargs["temperature"] = temperature_val + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_do_sample: + assert call_kwargs["temperature"] > 0 + assert call_kwargs.get("do_sample") is True + else: + assert call_kwargs["temperature"] == 0.0 + assert "do_sample" not in call_kwargs + + @pytest.mark.parametrize( + ("do_sample_val", "expect_sampling"), + [ + (True, True), # bool True → sampling on + (False, False), # bool False → greedy + ("true", True), # string "true" → sampling on + ("false", False), # string "false" → greedy (was truthy before fix) + ("0", False), # string "0" → greedy + ("1", True), # string "1" → sampling + (1, True), # int 1 → sampling + (0, False), # int 0 → greedy + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_coerces_do_sample(self, mock_params_cls, mock_gen_cls, do_sample_val, expect_sampling): + """do_sample must be coerced to a real bool so string 'false'/'0' are not truthy.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request( + "prompt", {"until": [], "max_gen_toks": 10, "do_sample": do_sample_val, "temperature": 0.7} + ) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_sampling: + assert call_kwargs["temperature"] > 0, f"Expected sampling for do_sample={do_sample_val!r}" + assert call_kwargs.get("do_sample") is True, ( + f"do_sample=True must be set in search_options for do_sample={do_sample_val!r}" + ) + else: + assert call_kwargs["temperature"] == 0.0, f"Expected greedy for do_sample={do_sample_val!r}" + assert "do_sample" not in call_kwargs, ( + f"do_sample must not be set when greedy for do_sample={do_sample_val!r}" + ) 
+ + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_tuple_until(self, mock_params_cls, mock_gen_cls): + """Until as a tuple must not be wrapped as a single element — each string is a stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\n world" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + # Pass until as a tuple — previously this would silently produce no stop enforcement + request = self._make_mock_request("prompt", {"until": ("\n",), "max_gen_toks": 256}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results[0] == "hello", f"Expected stop at \\n but got: {results[0]!r}" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_processes_multiple_requests_independently(self, mock_params_cls, mock_gen_cls): + """Multiple requests must not share mutable state (tail, stop_found, token_ids).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + # First request decodes to text with a stop; second decodes cleanly + evaluator.tokenizer.decode.side_effect = [ + "\n", # per-token tail for req 1 (stop sequence present) + "hello\n", # full-sequence decode for req 1 + "world", # full-sequence decode for req 2 (no stop) + ] + + mock_generator = MagicMock() + # Req 1: is_done=False → generates token 10 → stop seq found → break (no more is_done) + # Req 2: is_done=False → generates token 20 → is_done=True → exit loop + mock_generator.is_done.side_effect = [False, False, True] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), # req 1 token + MagicMock(__getitem__=lambda s, k: 20), # req 2 token + ] + mock_gen_cls.return_value = mock_generator + + req1 = self._make_mock_request("p1", {"until": ["\n"], "max_gen_toks": 64}) + req2 = self._make_mock_request("p2", {"until": [], "max_gen_toks": 64}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [req1, req2]) + + assert results[0] == "hello" # trimmed at \n + assert results[1] == "world" # no stop, full text + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_calls_cache_hook(self, mock_params_cls, mock_gen_cls): + """cache_hook.add_partial must be called with the final generated text.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + 
mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + request.cache_hook.add_partial.assert_called_once_with("generate_until", request.args, "hello")
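
Note: a minimal usage sketch of the new option, mirroring the unit tests above rather than prescribing the final API surface. The same setting is exposed on the CLI as --confirm_run_unsafe_code, which the benchmark.py hunk writes into the evaluators.evaluator.confirm_run_unsafe_code config key; exact run-config wiring beyond that key is assumed here.

# Sketch only (assumes Olive and lm-eval are installed; names match the tests above).
# confirm_run_unsafe_code is forwarded to lm_eval.simple_evaluate only when the
# installed lm-eval version accepts that keyword (see _simple_evaluate_supports_unsafe_code).
from olive.evaluator.olive_evaluator import LMEvaluator

evaluator = LMEvaluator(
    tasks=["mbpp"],                # generative task that executes model-produced code
    model_class="ortgenai",        # backend whose generate_until support this diff adds
    batch_size=1,
    max_length=128,
    confirm_run_unsafe_code=True,  # opt in to running model-generated code
)
# Then, as in the tests, evaluation proceeds via e.g.:
# evaluator.evaluate(model, metrics=[], device=Device.CPU,
#                    execution_providers=["CPUExecutionProvider"])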