diff --git a/tests/unit/types/test_chat_completion_validation.py b/tests/unit/types/test_chat_completion_validation.py
new file mode 100644
index 0000000..c991f33
--- /dev/null
+++ b/tests/unit/types/test_chat_completion_validation.py
@@ -0,0 +1,74 @@
+import pytest
+
+from tlm.utils.chat_completion_validation import _validate_chat_completion_params
+
+
+def test_validate_chat_completion_params_allows_valid_openai_keys() -> None:
+    params = {"messages": [], "model": "gpt-4.1", "temperature": 0.5}
+
+    _validate_chat_completion_params(params)
+
+
+def test_validate_chat_completion_params_allows_provider_as_none() -> None:
+    params = {"messages": [], "model": "gpt-4.1", "temperature": 0.5}
+
+    _validate_chat_completion_params(params)
+
+
+def test_validate_chat_completion_params_requires_messages() -> None:
+    params = {"model": "gpt-4.1-mini"}
+
+    with pytest.raises(ValueError) as exc_info:
+        _validate_chat_completion_params(params)
+
+    assert "openai_args must include the following parameter(s): messages" in str(exc_info.value)
+
+
+def test_validate_chat_completion_params_requires_messages_list() -> None:
+    params = {"messages": "not-a-list"}
+
+    with pytest.raises(ValueError) as exc_info:
+        _validate_chat_completion_params(params)
+
+    assert "`messages` must be provided as a list" in str(exc_info.value)
+
+
+def test_validate_chat_completion_params_requires_message_dict() -> None:
+    params = {"messages": ["not-a-dict"]}
+
+    with pytest.raises(ValueError) as exc_info:
+        _validate_chat_completion_params(params)
+
+    assert "messages[0] must be a dictionary" in str(exc_info.value)
+
+
+def test_validate_chat_completion_params_requires_role_and_content_strings() -> None:
+    params = {"messages": [{"role": 123, "content": None}]}
+
+    with pytest.raises(ValueError) as exc_info:
+        _validate_chat_completion_params(params)
+
+    assert "messages[0]['role']" in str(exc_info.value)
+
+
+def test_validate_chat_completion_params_allows_function_call_without_content() -> None:
+    params = {
+        "messages": [
+            {
+                "role": "assistant",
+                "content": None,
+                "function_call": {"name": "foo", "arguments": '{"bar": 1}'},
+            }
+        ]
+    }
+
+    _validate_chat_completion_params(params)
+
+
+def test_validate_chat_completion_params_requires_content_when_no_function_call() -> None:
+    params = {"messages": [{"role": "assistant", "content": None}]}
+
+    with pytest.raises(ValueError) as exc_info:
+        _validate_chat_completion_params(params)
+
+    assert "messages[0]['content'] must be a string." in str(exc_info.value)
diff --git a/tlm/api.py b/tlm/api.py
index c30beaa..0b4dc28 100644
--- a/tlm/api.py
+++ b/tlm/api.py
@@ -4,6 +4,7 @@
 from tlm.config.presets import WorkflowType
 from tlm.inference import InferenceResult, tlm_inference
 from tlm.types import SemanticEval, CompletionParams
+from tlm.utils.chat_completion_validation import _validate_chat_completion_params
 
 
 async def inference(
@@ -14,6 +15,7 @@
     evals: list[SemanticEval] | None = None,
     config_input: ConfigInput = ConfigInput(),
 ) -> InferenceResult:
+    _validate_chat_completion_params(openai_args)
     workflow_type = WorkflowType.from_inference_params(
         openai_args=openai_args,
         score=response is not None,
diff --git a/tlm/utils/chat_completion_validation.py b/tlm/utils/chat_completion_validation.py
new file mode 100644
index 0000000..93109c9
--- /dev/null
+++ b/tlm/utils/chat_completion_validation.py
@@ -0,0 +1,54 @@
+"""Validation helpers for chat completion parameter dictionaries."""
+
+from typing import Any, Callable, FrozenSet, Mapping
+
+from tlm.types.base import CompletionParams
+
+
+ParamValidator = Callable[[Any], None]
+
+REQUIRED_CHAT_COMPLETION_PARAMS: FrozenSet[str] = frozenset({"messages"})
+
+
+def _validate_messages_param(messages: Any) -> None:
+    """Validate the shape of a `messages` param for chat completions."""
+
+    if not isinstance(messages, list):
+        raise ValueError("`messages` must be provided as a list of message dictionaries.")
+
+    for index, message in enumerate(messages):
+        if not isinstance(message, dict):
+            raise ValueError(f"messages[{index}] must be a dictionary.")
+
+        role = message.get("role")
+        content = message.get("content")
+
+        if role is None or not isinstance(role, str):
+            raise ValueError(f"messages[{index}]['role']  must be a non-empty string.")
+
+        if content is None or not isinstance(content, str):
+            function_call = message.get("function_call")
+            if role != "assistant":
+                raise ValueError(f"Non-assistant message at index {index} must have content.")
+            if function_call is None:
+                raise ValueError(f"messages[{index}]['content'] must be a string. Assistant messages may provide a function_call instead.")
+
+
+REQUIRED_PARAM_VALIDATORS: Mapping[str, ParamValidator] = {
+    "messages": _validate_messages_param,
+}
+
+
+def _validate_chat_completion_params(params: CompletionParams) -> None:  # type: ignore
+    """Ensure only supported chat completion params are passed into inference."""
+
+    missing_required = [param for param in REQUIRED_CHAT_COMPLETION_PARAMS if param not in params]
+    if missing_required:
+        required_str = ", ".join(sorted(REQUIRED_CHAT_COMPLETION_PARAMS))
+        raise ValueError(f"openai_args must include the following parameter(s): {required_str}")
+
+    for param in REQUIRED_CHAT_COMPLETION_PARAMS:
+        validator = REQUIRED_PARAM_VALIDATORS.get(param)
+        if validator is None:
+            continue
+        validator(params[param])