diff --git a/agent/transports/__init__.py b/agent/transports/__init__.py
index 6cd3a277a10..cb628252e7b 100644
--- a/agent/transports/__init__.py
+++ b/agent/transports/__init__.py
@@ -37,3 +37,11 @@ def _discover_transports() -> None:
         import agent.transports.anthropic  # noqa: F401
     except ImportError:
         pass
+    try:
+        import agent.transports.codex  # noqa: F401
+    except ImportError:
+        pass
+    try:
+        import agent.transports.chat_completions  # noqa: F401
+    except ImportError:
+        pass
diff --git a/agent/transports/chat_completions.py b/agent/transports/chat_completions.py
new file mode 100644
index 00000000000..58876846b80
--- /dev/null
+++ b/agent/transports/chat_completions.py
@@ -0,0 +1,319 @@
+"""OpenAI Chat Completions transport.
+
+Handles the default api_mode ('chat_completions') used by ~16 OpenAI-compatible
+providers (OpenRouter, Nous, NVIDIA, Qwen, Ollama, DeepSeek, xAI, custom, etc.).
+
+Messages and tools are already in OpenAI format — convert_messages and
+convert_tools are near-identity. The complexity lives in build_kwargs
+which has provider-specific conditionals for max_tokens defaults,
+reasoning configuration, temperature handling, and extra_body assembly.
+"""
+
+import copy
+from typing import Any, Dict, List, Optional
+
+from agent.transports.base import ProviderTransport
+from agent.transports.types import NormalizedResponse, ToolCall, Usage
+
+
+from agent.prompt_builder import DEVELOPER_ROLE_MODELS
+
+
+class ChatCompletionsTransport(ProviderTransport):
+    """Transport for api_mode='chat_completions'.
+
+    The default path for OpenAI-compatible providers.
+ """ + + @property + def api_mode(self) -> str: + return "chat_completions" + + def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + """Messages are already in OpenAI format — sanitize codex leaks only.""" + sanitized = messages + needs_sanitize = False + for msg in messages: + if not isinstance(msg, dict): + continue + if "codex_reasoning_items" in msg: + needs_sanitize = True + break + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list): + for tc in tool_calls: + if isinstance(tc, dict) and ("call_id" in tc or "response_item_id" in tc): + needs_sanitize = True + break + if needs_sanitize: + break + + if needs_sanitize: + sanitized = copy.deepcopy(messages) + for msg in sanitized: + if not isinstance(msg, dict): + continue + msg.pop("codex_reasoning_items", None) + tool_calls = msg.get("tool_calls") + if isinstance(tool_calls, list): + for tc in tool_calls: + if isinstance(tc, dict): + tc.pop("call_id", None) + tc.pop("response_item_id", None) + + return sanitized + + def convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Tools are already in OpenAI format — identity.""" + return tools + + def build_kwargs( + self, + model: str, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + **params, + ) -> Dict[str, Any]: + """Build chat.completions.create() kwargs. + + This is the most complex transport method — it handles ~16 providers + via params rather than subclasses. 
+ + params: + timeout: float — API call timeout + max_tokens: int | None — user-configured max tokens + ephemeral_max_output_tokens: int | None — one-shot override (error recovery) + max_tokens_param_fn: callable — returns {max_tokens: N} or {max_completion_tokens: N} + reasoning_config: dict | None + request_overrides: dict | None + session_id: str | None + model_lower: str — lowercase model name for pattern matching + base_url_lower: str — lowercase base URL + # Provider detection flags + is_openrouter: bool + is_nous: bool + is_qwen_portal: bool + is_github_models: bool + is_nvidia_nim: bool + is_custom_provider: bool + ollama_num_ctx: int | None + # Provider routing + provider_preferences: dict | None + # Qwen-specific + qwen_prepare_fn: callable | None — prep fn applied AFTER codex sanitization + qwen_metadata: dict | None + # Temperature + fixed_temperature: Any — from _fixed_temperature_for_model() + omit_temperature: bool + # Reasoning + supports_reasoning: bool + github_reasoning_extra: dict | None + # Claude on OpenRouter max output + anthropic_max_output: int | None + # Extra + extra_body_additions: dict | None — pre-built extra_body entries + """ + # Start with sanitized messages (codex field stripping) + sanitized = self.convert_messages(messages) + + # Qwen portal prep AFTER codex sanitization (must transform sanitized messages) + qwen_fn = params.get("qwen_prepare_fn") + if qwen_fn is not None: + sanitized = qwen_fn(sanitized) + + # Developer role swap for GPT-5/Codex models + model_lower = params.get("model_lower", (model or "").lower()) + if ( + sanitized + and isinstance(sanitized[0], dict) + and sanitized[0].get("role") == "system" + and any(p in model_lower for p in DEVELOPER_ROLE_MODELS) + ): + sanitized = list(sanitized) + sanitized[0] = {**sanitized[0], "role": "developer"} + + api_kwargs: Dict[str, Any] = { + "model": model, + "messages": sanitized, + } + + timeout = params.get("timeout") + if timeout is not None: + api_kwargs["timeout"] = 
timeout + + # Temperature + fixed_temp = params.get("fixed_temperature") + omit_temp = params.get("omit_temperature", False) + if omit_temp: + api_kwargs.pop("temperature", None) + elif fixed_temp is not None: + api_kwargs["temperature"] = fixed_temp + + # Qwen metadata + qwen_meta = params.get("qwen_metadata") + if qwen_meta: + api_kwargs["metadata"] = qwen_meta + + # Tools + if tools: + api_kwargs["tools"] = tools + + # max_tokens resolution + max_tokens_fn = params.get("max_tokens_param_fn") + ephemeral = params.get("ephemeral_max_output_tokens") + max_tokens = params.get("max_tokens") + anthropic_max_out = params.get("anthropic_max_output") + + if ephemeral is not None and max_tokens_fn: + api_kwargs.update(max_tokens_fn(ephemeral)) + elif max_tokens is not None and max_tokens_fn: + api_kwargs.update(max_tokens_fn(max_tokens)) + elif params.get("is_nvidia_nim") and max_tokens_fn: + api_kwargs.update(max_tokens_fn(16384)) + elif params.get("is_qwen_portal") and max_tokens_fn: + api_kwargs.update(max_tokens_fn(65536)) + elif anthropic_max_out is not None: + api_kwargs["max_tokens"] = anthropic_max_out + + # extra_body assembly + extra_body: Dict[str, Any] = {} + + is_openrouter = params.get("is_openrouter", False) + is_nous = params.get("is_nous", False) + is_github_models = params.get("is_github_models", False) + + provider_prefs = params.get("provider_preferences") + if provider_prefs and is_openrouter: + extra_body["provider"] = provider_prefs + + # Reasoning + if params.get("supports_reasoning", False): + if is_github_models: + gh_reasoning = params.get("github_reasoning_extra") + if gh_reasoning is not None: + extra_body["reasoning"] = gh_reasoning + else: + reasoning_config = params.get("reasoning_config") + if reasoning_config is not None: + rc = dict(reasoning_config) + if is_nous and rc.get("enabled") is False: + pass # omit for Nous when disabled + else: + extra_body["reasoning"] = rc + else: + extra_body["reasoning"] = {"enabled": True, "effort": 
"medium"} + + if is_nous: + extra_body["tags"] = ["product=hermes-agent"] + + # Ollama num_ctx + ollama_ctx = params.get("ollama_num_ctx") + if ollama_ctx: + options = extra_body.get("options", {}) + options["num_ctx"] = ollama_ctx + extra_body["options"] = options + + # Ollama/custom think=false + if params.get("is_custom_provider", False): + reasoning_config = params.get("reasoning_config") + if reasoning_config and isinstance(reasoning_config, dict): + _effort = (reasoning_config.get("effort") or "").strip().lower() + _enabled = reasoning_config.get("enabled", True) + if _effort == "none" or _enabled is False: + extra_body["think"] = False + + if params.get("is_qwen_portal"): + extra_body["vl_high_resolution_images"] = True + + # Merge any pre-built extra_body additions + additions = params.get("extra_body_additions") + if additions: + extra_body.update(additions) + + if extra_body: + api_kwargs["extra_body"] = extra_body + + # Request overrides last + overrides = params.get("request_overrides") + if overrides: + api_kwargs.update(overrides) + + return api_kwargs + + def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse: + """Normalize OpenAI ChatCompletion to NormalizedResponse. + + For chat_completions, this is near-identity — the response is + already in OpenAI format. + """ + choice = response.choices[0] + msg = choice.message + finish_reason = choice.finish_reason or "stop" + + tool_calls = None + if msg.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + ) + for tc in msg.tool_calls + ] + + usage = None + if hasattr(response, "usage") and response.usage: + u = response.usage + usage = Usage( + prompt_tokens=getattr(u, "prompt_tokens", 0) or 0, + completion_tokens=getattr(u, "completion_tokens", 0) or 0, + total_tokens=getattr(u, "total_tokens", 0) or 0, + ) + + # reasoning_content is used by some providers (DeepSeek, etc.) 
+ reasoning = getattr(msg, "reasoning_content", None) or getattr(msg, "reasoning", None) + + # reasoning_details carries encrypted reasoning blocks for cross-turn replay + provider_data = None + rd = getattr(msg, "reasoning_details", None) + if rd: + provider_data = {"reasoning_details": rd} + + return NormalizedResponse( + content=msg.content, + tool_calls=tool_calls, + finish_reason=finish_reason, + reasoning=reasoning, + usage=usage, + provider_data=provider_data, + ) + + def validate_response(self, response: Any) -> bool: + """Check that response has valid choices.""" + if response is None: + return False + if not hasattr(response, "choices") or response.choices is None: + return False + if not response.choices: + return False + return True + + def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]: + """Extract OpenRouter/OpenAI cache stats from prompt_tokens_details.""" + usage = getattr(response, "usage", None) + if usage is None: + return None + details = getattr(usage, "prompt_tokens_details", None) + if details is None: + return None + cached = getattr(details, "cached_tokens", 0) or 0 + written = getattr(details, "cache_write_tokens", 0) or 0 + if cached or written: + return {"cached_tokens": cached, "creation_tokens": written} + return None + + +# Auto-register on import +from agent.transports import register_transport # noqa: E402 + +register_transport("chat_completions", ChatCompletionsTransport) diff --git a/run_agent.py b/run_agent.py index 722f7cea4b3..7ccd54546bc 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6554,6 +6554,24 @@ def _get_anthropic_transport(self): self._anthropic_transport = t return t + def _get_codex_transport(self): + """Return the cached ResponsesApiTransport instance (lazy singleton).""" + t = getattr(self, "_codex_transport", None) + if t is None: + from agent.transports import get_transport + t = get_transport("codex_responses") + self._codex_transport = t + return t + + def 
_get_chat_completions_transport(self): + """Return the cached ChatCompletionsTransport instance (lazy singleton).""" + t = getattr(self, "_chat_completions_transport", None) + if t is None: + from agent.transports import get_transport + t = get_transport("chat_completions") + self._chat_completions_transport = t + return t + def _prepare_anthropic_messages_for_api(self, api_messages: list) -> list: if not any( isinstance(msg, dict) and self._content_has_image_parts(msg.get("content")) @@ -6639,34 +6657,6 @@ def _qwen_prepare_chat_messages(self, api_messages: list) -> list: return prepared - def _qwen_prepare_chat_messages_inplace(self, messages: list) -> None: - """In-place variant — mutates an already-copied message list.""" - if not messages: - return - - for msg in messages: - if not isinstance(msg, dict): - continue - content = msg.get("content") - if isinstance(content, str): - msg["content"] = [{"type": "text", "text": content}] - elif isinstance(content, list): - normalized_parts = [] - for part in content: - if isinstance(part, str): - normalized_parts.append({"type": "text", "text": part}) - elif isinstance(part, dict): - normalized_parts.append(part) - if normalized_parts: - msg["content"] = normalized_parts - - for msg in messages: - if isinstance(msg, dict) and msg.get("role") == "system": - content = msg.get("content") - if isinstance(content, list) and content and isinstance(content[-1], dict): - content[-1]["cache_control"] = {"type": "ephemeral"} - break - def _build_api_kwargs(self, api_messages: list) -> dict: """Build the keyword arguments dict for the active API mode.""" if self.api_mode == "anthropic_messages": @@ -6788,216 +6778,87 @@ def _build_api_kwargs(self, api_messages: list) -> dict: return kwargs - sanitized_messages = api_messages - needs_sanitization = False - for msg in api_messages: - if not isinstance(msg, dict): - continue - if "codex_reasoning_items" in msg: - needs_sanitization = True - break - - tool_calls = 
msg.get("tool_calls") - if isinstance(tool_calls, list): - for tool_call in tool_calls: - if not isinstance(tool_call, dict): - continue - if "call_id" in tool_call or "response_item_id" in tool_call: - needs_sanitization = True - break - if needs_sanitization: - break + # ── chat_completions (default) ───────────────────────────────────── + _ct = self._get_chat_completions_transport() - if needs_sanitization: - sanitized_messages = copy.deepcopy(api_messages) - for msg in sanitized_messages: - if not isinstance(msg, dict): - continue - - # Codex-only replay state must not leak into strict chat-completions APIs. - msg.pop("codex_reasoning_items", None) - - tool_calls = msg.get("tool_calls") - if isinstance(tool_calls, list): - for tool_call in tool_calls: - if isinstance(tool_call, dict): - tool_call.pop("call_id", None) - tool_call.pop("response_item_id", None) - - # Qwen portal: normalize content to list-of-dicts, inject cache_control. - # Must run AFTER codex sanitization so we transform the final messages. - # If sanitization already deepcopied, reuse that copy (in-place). - if self._is_qwen_portal(): - if sanitized_messages is api_messages: - # No sanitization was done — we need our own copy. - sanitized_messages = self._qwen_prepare_chat_messages(sanitized_messages) - else: - # Already a deepcopy — transform in place to avoid a second deepcopy. - self._qwen_prepare_chat_messages_inplace(sanitized_messages) + # Provider detection flags + _is_qwen = self._is_qwen_portal() + _is_or = self._is_openrouter_url() + _is_gh = ( + base_url_host_matches(self._base_url_lower, "models.github.ai") + or base_url_host_matches(self._base_url_lower, "api.githubcopilot.com") + ) + _is_nous = "nousresearch" in self._base_url_lower - # GPT-5 and Codex models respond better to 'developer' than 'system' - # for instruction-following. Swap the role at the API boundary so - # internal message representation stays uniform ("system"). 
- _model_lower = (self.model or "").lower() - if ( - sanitized_messages - and sanitized_messages[0].get("role") == "system" - and any(p in _model_lower for p in DEVELOPER_ROLE_MODELS) - ): - # Shallow-copy the list + first message only — rest stays shared. - sanitized_messages = list(sanitized_messages) - sanitized_messages[0] = {**sanitized_messages[0], "role": "developer"} + # Temperature + try: + from agent.auxiliary_client import _fixed_temperature_for_model, OMIT_TEMPERATURE + _ft = _fixed_temperature_for_model(self.model, self.base_url) + _omit_temp = _ft is OMIT_TEMPERATURE + _fixed_temp = _ft if not _omit_temp else None + except Exception: + _omit_temp = False + _fixed_temp = None - provider_preferences = {} + # Provider preferences + _prefs = {} if self.providers_allowed: - provider_preferences["only"] = self.providers_allowed + _prefs["only"] = self.providers_allowed if self.providers_ignored: - provider_preferences["ignore"] = self.providers_ignored + _prefs["ignore"] = self.providers_ignored if self.providers_order: - provider_preferences["order"] = self.providers_order + _prefs["order"] = self.providers_order if self.provider_sort: - provider_preferences["sort"] = self.provider_sort + _prefs["sort"] = self.provider_sort if self.provider_require_parameters: - provider_preferences["require_parameters"] = True + _prefs["require_parameters"] = True if self.provider_data_collection: - provider_preferences["data_collection"] = self.provider_data_collection + _prefs["data_collection"] = self.provider_data_collection - api_kwargs = { - "model": self.model, - "messages": sanitized_messages, - "timeout": self._resolved_api_call_timeout(), - } - try: - from agent.auxiliary_client import _fixed_temperature_for_model, OMIT_TEMPERATURE - except Exception: - _fixed_temperature_for_model = None - OMIT_TEMPERATURE = None - if _fixed_temperature_for_model is not None: - fixed_temperature = _fixed_temperature_for_model(self.model, self.base_url) - if fixed_temperature is 
OMIT_TEMPERATURE: - api_kwargs.pop("temperature", None) - elif fixed_temperature is not None: - api_kwargs["temperature"] = fixed_temperature - if self._is_qwen_portal(): - api_kwargs["metadata"] = { - "sessionId": self.session_id or "hermes", - "promptId": str(uuid.uuid4()), - } - if self.tools: - api_kwargs["tools"] = self.tools - - # ── max_tokens for chat_completions ────────────────────────────── - # Priority: ephemeral override (error recovery / length-continuation - # boost) > user-configured max_tokens > provider-specific defaults. - _ephemeral_out = getattr(self, "_ephemeral_max_output_tokens", None) - if _ephemeral_out is not None: - self._ephemeral_max_output_tokens = None # consume immediately - api_kwargs.update(self._max_tokens_param(_ephemeral_out)) - elif self.max_tokens is not None: - api_kwargs.update(self._max_tokens_param(self.max_tokens)) - elif "integrate.api.nvidia.com" in self._base_url_lower: - # NVIDIA NIM defaults to a very low max_tokens when omitted, - # causing models like GLM-4.7 to truncate immediately (thinking - # tokens alone exhaust the budget). 16384 provides adequate room. - api_kwargs.update(self._max_tokens_param(16384)) - elif self._is_qwen_portal(): - # Qwen Portal defaults to a very low max_tokens when omitted. - # Reasoning models (qwen3-coder-plus) exhaust that budget on - # thinking tokens alone, causing the portal to return - # finish_reason="stop" with truncated output — the agent sees - # this as an intentional stop and exits the loop. Send 65536 - # (the documented max output for qwen3-coder models) so the - # model has adequate output budget for tool calls. - api_kwargs.update(self._max_tokens_param(65536)) - elif (self._is_openrouter_url() or "nousresearch" in self._base_url_lower) and "claude" in (self.model or "").lower(): - # OpenRouter and Nous Portal translate requests to Anthropic's - # Messages API, which requires max_tokens as a mandatory field. 
- # When we omit it, the proxy picks a default that can be too - # low — the model spends its output budget on thinking and has - # almost nothing left for the actual response (especially large - # tool calls like write_file). Sending the model's real output - # limit ensures full capacity. + # Anthropic max output for Claude on OpenRouter/Nous + _ant_max = None + if (_is_or or _is_nous) and "claude" in (self.model or "").lower(): try: from agent.anthropic_adapter import _get_anthropic_max_output - _model_output_limit = _get_anthropic_max_output(self.model) - api_kwargs["max_tokens"] = _model_output_limit + _ant_max = _get_anthropic_max_output(self.model) except Exception: - pass # fail open — let the proxy pick its default + pass - extra_body = {} + # Ephemeral max output override + _ephemeral_out = getattr(self, "_ephemeral_max_output_tokens", None) + if _ephemeral_out is not None: + self._ephemeral_max_output_tokens = None # consume immediately - _is_openrouter = self._is_openrouter_url() - _is_github_models = ( - base_url_host_matches(self._base_url_lower, "models.github.ai") - or base_url_host_matches(self._base_url_lower, "api.githubcopilot.com") + return _ct.build_kwargs( + model=self.model, + messages=api_messages, + tools=self.tools, + timeout=self._resolved_api_call_timeout(), + max_tokens=self.max_tokens, + ephemeral_max_output_tokens=_ephemeral_out, + max_tokens_param_fn=self._max_tokens_param, + reasoning_config=self.reasoning_config, + request_overrides=self.request_overrides, + session_id=getattr(self, "session_id", None), + model_lower=(self.model or "").lower(), + base_url_lower=self._base_url_lower, + is_openrouter=_is_or, + is_nous=_is_nous, + is_qwen_portal=_is_qwen, + is_github_models=_is_gh, + is_nvidia_nim="integrate.api.nvidia.com" in self._base_url_lower, + is_custom_provider=self.provider == "custom", + ollama_num_ctx=self._ollama_num_ctx, + provider_preferences=_prefs or None, + qwen_prepare_fn=self._qwen_prepare_chat_messages if _is_qwen 
else None, + qwen_metadata={"sessionId": self.session_id or "hermes", "promptId": str(uuid.uuid4())} if _is_qwen else None, + fixed_temperature=_fixed_temp, + omit_temperature=_omit_temp, + supports_reasoning=self._supports_reasoning_extra_body(), + github_reasoning_extra=self._github_models_reasoning_extra_body() if _is_gh else None, + anthropic_max_output=_ant_max, ) - # Provider preferences (only, ignore, order, sort) are OpenRouter- - # specific. Only send to OpenRouter-compatible endpoints. - # TODO: Nous Portal will add transparent proxy support — re-enable - # for _is_nous when their backend is updated. - if provider_preferences and _is_openrouter: - extra_body["provider"] = provider_preferences - _is_nous = "nousresearch" in self._base_url_lower - - if self._supports_reasoning_extra_body(): - if _is_github_models: - github_reasoning = self._github_models_reasoning_extra_body() - if github_reasoning is not None: - extra_body["reasoning"] = github_reasoning - else: - if self.reasoning_config is not None: - rc = dict(self.reasoning_config) - # Nous Portal requires reasoning enabled — don't send - # enabled=false to it (would cause 400). - if _is_nous and rc.get("enabled") is False: - pass # omit reasoning entirely for Nous when disabled - else: - extra_body["reasoning"] = rc - else: - extra_body["reasoning"] = { - "enabled": True, - "effort": "medium" - } - - # Nous Portal product attribution - if _is_nous: - extra_body["tags"] = ["product=hermes-agent"] - - # Ollama num_ctx: override the 2048 default so the model actually - # uses the context window it was trained for. Passed via the OpenAI - # SDK's extra_body → options.num_ctx, which Ollama's OpenAI-compat - # endpoint forwards to the runner as --ctx-size. - if self._ollama_num_ctx: - options = extra_body.get("options", {}) - options["num_ctx"] = self._ollama_num_ctx - extra_body["options"] = options - - # Ollama / custom provider: pass think=false when reasoning is disabled. 
- # Ollama does not recognise the OpenRouter-style `reasoning` extra_body - # field, so we use its native `think` parameter instead. - # This prevents thinking-capable models (Qwen3, etc.) from generating - # blocks and producing empty-response errors when the user has - # set reasoning_effort: none. - if self.provider == "custom" and self.reasoning_config and isinstance(self.reasoning_config, dict): - _effort = (self.reasoning_config.get("effort") or "").strip().lower() - _enabled = self.reasoning_config.get("enabled", True) - if _effort == "none" or _enabled is False: - extra_body["think"] = False - - if self._is_qwen_portal(): - extra_body["vl_high_resolution_images"] = True - - if extra_body: - api_kwargs["extra_body"] = extra_body - - # Priority Processing / generic request overrides (e.g. service_tier). - # Applied last so overrides win over any defaults set above. - if self.request_overrides: - api_kwargs.update(self.request_overrides) - - return api_kwargs - def _supports_reasoning_extra_body(self) -> bool: """Return True when reasoning extra_body is safe to send for this route/model. 
@@ -9373,7 +9234,8 @@ def _stop_spinner(): else: error_details.append("response.content invalid (not a non-empty list)") else: - if response is None or not hasattr(response, 'choices') or response.choices is None or not response.choices: + _ctv = self._get_chat_completions_transport() + if not _ctv.validate_response(response): response_invalid = True if response is None: error_details.append("response is None") @@ -9536,6 +9398,7 @@ def _stop_spinner(): finish_reason = _tfr.map_finish_reason(response.stop_reason) else: finish_reason = response.choices[0].finish_reason + # Pre-extract for truncation heuristic (raw SDK message needed) assistant_message = response.choices[0].message if self._should_treat_stop_as_truncated( finish_reason, @@ -9843,10 +9706,10 @@ def _stop_spinner(): cached = _cache["cached_tokens"] if _cache else 0 written = _cache["creation_tokens"] if _cache else 0 else: - # OpenRouter uses prompt_tokens_details.cached_tokens - details = getattr(response.usage, 'prompt_tokens_details', None) - cached = getattr(details, 'cached_tokens', 0) or 0 if details else 0 - written = getattr(details, 'cache_write_tokens', 0) or 0 if details else 0 + _ctc = self._get_chat_completions_transport() + _cc_cache = _ctc.extract_cache_stats(response) + cached = _cc_cache["cached_tokens"] if _cc_cache else 0 + written = _cc_cache["creation_tokens"] if _cc_cache else 0 prompt = usage_dict["prompt_tokens"] hit_pct = (cached / prompt * 100) if prompt > 0 else 0 if not self.quiet_mode: @@ -10812,7 +10675,26 @@ def _stop_spinner(): ) finish_reason = _nr.finish_reason else: - assistant_message = response.choices[0].message + _cct = self._get_chat_completions_transport() + _cc_nr = _cct.normalize_response(response) + assistant_message = SimpleNamespace( + content=_cc_nr.content, + tool_calls=[ + SimpleNamespace( + id=tc.id, + type="function", + function=SimpleNamespace(name=tc.name, arguments=tc.arguments), + ) + for tc in (_cc_nr.tool_calls or []) + ] or None, + 
reasoning=_cc_nr.reasoning, + reasoning_content=None, + reasoning_details=( + _cc_nr.provider_data.get("reasoning_details") + if _cc_nr.provider_data else None + ), + ) + finish_reason = _cc_nr.finish_reason # Normalize content to string — some OpenAI-compatible servers # (llama-server, etc.) return content as a dict or list instead diff --git a/tests/agent/transports/test_chat_completions.py b/tests/agent/transports/test_chat_completions.py new file mode 100644 index 00000000000..7b599512bf1 --- /dev/null +++ b/tests/agent/transports/test_chat_completions.py @@ -0,0 +1,203 @@ +"""Tests for the ChatCompletionsTransport.""" + +import pytest +from types import SimpleNamespace + +from agent.transports import get_transport +from agent.transports.types import NormalizedResponse, ToolCall + + +@pytest.fixture +def transport(): + import agent.transports.chat_completions # noqa: F401 + return get_transport("chat_completions") + + +class TestChatCompletionsBasic: + + def test_api_mode(self, transport): + assert transport.api_mode == "chat_completions" + + def test_registered(self, transport): + assert transport is not None + + def test_convert_tools_identity(self, transport): + tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}] + assert transport.convert_tools(tools) is tools + + def test_convert_messages_no_codex_leaks(self, transport): + msgs = [{"role": "user", "content": "hi"}] + result = transport.convert_messages(msgs) + assert result is msgs # no copy needed + + def test_convert_messages_strips_codex_fields(self, transport): + msgs = [ + {"role": "assistant", "content": "ok", "codex_reasoning_items": [{"id": "rs_1"}], + "tool_calls": [{"id": "call_1", "call_id": "call_1", "response_item_id": "fc_1", + "type": "function", "function": {"name": "t", "arguments": "{}"}}]}, + ] + result = transport.convert_messages(msgs) + assert "codex_reasoning_items" not in result[0] + assert "call_id" not in result[0]["tool_calls"][0] + assert 
"response_item_id" not in result[0]["tool_calls"][0] + + +class TestChatCompletionsBuildKwargs: + + def test_basic_kwargs(self, transport): + msgs = [{"role": "user", "content": "Hello"}] + kw = transport.build_kwargs(model="gpt-4o", messages=msgs, timeout=30.0) + assert kw["model"] == "gpt-4o" + assert kw["messages"][0]["content"] == "Hello" + assert kw["timeout"] == 30.0 + + def test_developer_role_swap(self, transport): + msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}] + kw = transport.build_kwargs(model="gpt-5.4", messages=msgs, model_lower="gpt-5.4") + assert kw["messages"][0]["role"] == "developer" + + def test_no_developer_swap_for_non_gpt5(self, transport): + msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}] + kw = transport.build_kwargs(model="claude-sonnet-4", messages=msgs, model_lower="claude-sonnet-4") + assert kw["messages"][0]["role"] == "system" + + def test_tools_included(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}] + kw = transport.build_kwargs(model="gpt-4o", messages=msgs, tools=tools) + assert kw["tools"] == tools + + def test_openrouter_provider_prefs(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gpt-4o", messages=msgs, + is_openrouter=True, + provider_preferences={"only": ["openai"]}, + ) + assert kw["extra_body"]["provider"] == {"only": ["openai"]} + + def test_nous_tags(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs(model="gpt-4o", messages=msgs, is_nous=True) + assert kw["extra_body"]["tags"] == ["product=hermes-agent"] + + def test_reasoning_default(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gpt-4o", messages=msgs, + supports_reasoning=True, + ) + assert kw["extra_body"]["reasoning"] == 
{"enabled": True, "effort": "medium"} + + def test_ollama_num_ctx(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="llama3", messages=msgs, + ollama_num_ctx=32768, + ) + assert kw["extra_body"]["options"]["num_ctx"] == 32768 + + def test_custom_think_false(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="qwen3", messages=msgs, + is_custom_provider=True, + reasoning_config={"effort": "none"}, + ) + assert kw["extra_body"]["think"] is False + + def test_max_tokens_with_fn(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gpt-4o", messages=msgs, + max_tokens=4096, + max_tokens_param_fn=lambda n: {"max_tokens": n}, + ) + assert kw["max_tokens"] == 4096 + + def test_ephemeral_overrides_max_tokens(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gpt-4o", messages=msgs, + max_tokens=4096, + ephemeral_max_output_tokens=2048, + max_tokens_param_fn=lambda n: {"max_tokens": n}, + ) + assert kw["max_tokens"] == 2048 + + def test_request_overrides_last(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="gpt-4o", messages=msgs, + request_overrides={"service_tier": "priority"}, + ) + assert kw["service_tier"] == "priority" + + +class TestChatCompletionsValidate: + + def test_none(self, transport): + assert transport.validate_response(None) is False + + def test_no_choices(self, transport): + r = SimpleNamespace(choices=None) + assert transport.validate_response(r) is False + + def test_empty_choices(self, transport): + r = SimpleNamespace(choices=[]) + assert transport.validate_response(r) is False + + def test_valid(self, transport): + r = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="hi"))]) + assert transport.validate_response(r) is True + + +class TestChatCompletionsNormalize: + + def 
test_text_response(self, transport): + r = SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None), + finish_reason="stop", + )], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + nr = transport.normalize_response(r) + assert isinstance(nr, NormalizedResponse) + assert nr.content == "Hello" + assert nr.finish_reason == "stop" + assert nr.tool_calls is None + + def test_tool_call_response(self, transport): + tc = SimpleNamespace( + id="call_123", + function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'), + ) + r = SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace(content=None, tool_calls=[tc], reasoning_content=None), + finish_reason="tool_calls", + )], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + nr = transport.normalize_response(r) + assert len(nr.tool_calls) == 1 + assert nr.tool_calls[0].name == "terminal" + assert nr.tool_calls[0].id == "call_123" + + +class TestChatCompletionsCacheStats: + + def test_no_usage(self, transport): + r = SimpleNamespace(usage=None) + assert transport.extract_cache_stats(r) is None + + def test_no_details(self, transport): + r = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=None)) + assert transport.extract_cache_stats(r) is None + + def test_with_cache(self, transport): + details = SimpleNamespace(cached_tokens=500, cache_write_tokens=100) + r = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=details)) + result = transport.extract_cache_stats(r) + assert result == {"cached_tokens": 500, "creation_tokens": 100}