From f56c373a279e4a61c761d7a2f17e6e6547594e32 Mon Sep 17 00:00:00 2001 From: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:53:46 +0530 Subject: [PATCH] feat: add BedrockTransport + wire all Bedrock transport paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add BedrockTransport wrapping agent/bedrock_adapter.py behind the ProviderTransport ABC. Fourth and final transport. Wire ALL transport methods to production paths in run_agent.py: - build_kwargs: _build_api_kwargs bedrock branch (L6713) - normalize_response: main normalize loop, new bedrock_converse branch (handles both raw boto3 dicts and already-normalized SimpleNamespace) - validate_response: response validation, new bedrock_converse branch - finish_reason: new bedrock_converse branch in finish_reason extraction The truncation path (L9588) intentionally groups bedrock with chat_completions — both have the same response.choices shape because normalize_converse_response runs at the dispatch site. 17 new tests. 231 bedrock/converse/transport tests pass (0 failures). PR 6 of the provider transport refactor. --- agent/transports/__init__.py | 8 + agent/transports/bedrock.py | 154 ++++++++++++++++ run_agent.py | 70 ++++++-- .../transports/test_bedrock_transport.py | 164 ++++++++++++++++++ 4 files changed, 383 insertions(+), 13 deletions(-) create mode 100644 agent/transports/bedrock.py create mode 100644 tests/agent/transports/test_bedrock_transport.py diff --git a/agent/transports/__init__.py b/agent/transports/__init__.py index 6cd3a277a10..5184be5c7ed 100644 --- a/agent/transports/__init__.py +++ b/agent/transports/__init__.py @@ -37,3 +37,11 @@ def _discover_transports() -> None: import agent.transports.anthropic # noqa: F401 except ImportError: pass + try: + import agent.transports.codex # noqa: F401 + except ImportError: + pass + try: + import agent.transports.bedrock # noqa: F401 + except ImportError: + pass diff --git a/agent/transports/bedrock.py b/agent/transports/bedrock.py new file mode 100644 index 00000000000..af549e7eae6 --- /dev/null +++ b/agent/transports/bedrock.py @@ -0,0 +1,154 @@ +"""AWS Bedrock Converse API transport. + +Delegates to the existing adapter functions in agent/bedrock_adapter.py. +Bedrock uses its own boto3 client (not the OpenAI SDK), so the transport +owns format conversion and normalization, while client construction and +boto3 calls stay on AIAgent. +""" + +from typing import Any, Dict, List, Optional + +from agent.transports.base import ProviderTransport +from agent.transports.types import NormalizedResponse, ToolCall, Usage + + +class BedrockTransport(ProviderTransport): + """Transport for api_mode='bedrock_converse'.""" + + @property + def api_mode(self) -> str: + return "bedrock_converse" + + def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any: + """Convert OpenAI messages to Bedrock Converse format.""" + from agent.bedrock_adapter import convert_messages_to_converse + return convert_messages_to_converse(messages) + + def convert_tools(self, tools: List[Dict[str, Any]]) -> Any: + """Convert OpenAI tool schemas to Bedrock Converse toolConfig.""" + from agent.bedrock_adapter import convert_tools_to_converse + return convert_tools_to_converse(tools) + + def build_kwargs( + self, + model: str, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + **params, + ) -> Dict[str, Any]: + """Build Bedrock converse() kwargs. + + Calls convert_messages and convert_tools internally. + + params: + max_tokens: int — output token limit (default 4096) + temperature: float | None + guardrail_config: dict | None — Bedrock guardrails + region: str — AWS region (default 'us-east-1') + """ + from agent.bedrock_adapter import build_converse_kwargs + + region = params.get("region", "us-east-1") + guardrail = params.get("guardrail_config") + + kwargs = build_converse_kwargs( + model=model, + messages=messages, + tools=tools, + max_tokens=params.get("max_tokens", 4096), + temperature=params.get("temperature"), + guardrail_config=guardrail, + ) + # Sentinel keys for dispatch — agent pops these before the boto3 call + kwargs["__bedrock_converse__"] = True + kwargs["__bedrock_region__"] = region + return kwargs + + def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse: + """Normalize Bedrock response to NormalizedResponse. + + Handles two shapes: + 1. Raw boto3 dict (from direct converse() calls) + 2. Already-normalized SimpleNamespace with .choices (from dispatch site) + """ + from agent.bedrock_adapter import normalize_converse_response + + # Normalize to OpenAI-compatible SimpleNamespace + if hasattr(response, "choices") and response.choices: + # Already normalized at dispatch site + ns = response + else: + # Raw boto3 dict + ns = normalize_converse_response(response) + + choice = ns.choices[0] + msg = choice.message + finish_reason = choice.finish_reason or "stop" + + tool_calls = None + if msg.tool_calls: + tool_calls = [ + ToolCall( + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + ) + for tc in msg.tool_calls + ] + + usage = None + if hasattr(ns, "usage") and ns.usage: + u = ns.usage + usage = Usage( + prompt_tokens=getattr(u, "prompt_tokens", 0) or 0, + completion_tokens=getattr(u, "completion_tokens", 0) or 0, + total_tokens=getattr(u, "total_tokens", 0) or 0, + ) + + reasoning = getattr(msg, "reasoning", None) or getattr(msg, "reasoning_content", None) + + return NormalizedResponse( + content=msg.content, + tool_calls=tool_calls, + finish_reason=finish_reason, + reasoning=reasoning, + usage=usage, + ) + + def validate_response(self, response: Any) -> bool: + """Check Bedrock response structure. + + After normalize_converse_response, the response has OpenAI-compatible + .choices — same check as chat_completions. + """ + if response is None: + return False + # Raw Bedrock dict response — check for 'output' key + if isinstance(response, dict): + return "output" in response + # Already-normalized SimpleNamespace + if hasattr(response, "choices"): + return bool(response.choices) + return False + + def map_finish_reason(self, raw_reason: str) -> str: + """Map Bedrock stop reason to OpenAI finish_reason. + + The adapter already does this mapping inside normalize_converse_response, + so this is only used for direct access to raw responses. + """ + _MAP = { + "end_turn": "stop", + "tool_use": "tool_calls", + "max_tokens": "length", + "stop_sequence": "stop", + "guardrail_intervened": "content_filter", + "content_filtered": "content_filter", + } + return _MAP.get(raw_reason, "stop") + + +# Auto-register on import +from agent.transports import register_transport # noqa: E402 + +register_transport("bedrock_converse", BedrockTransport) diff --git a/run_agent.py b/run_agent.py index 722f7cea4b3..eefaf5565af 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6554,6 +6554,24 @@ def _get_anthropic_transport(self): self._anthropic_transport = t return t + def _get_codex_transport(self): + """Return the cached ResponsesApiTransport instance (lazy singleton).""" + t = getattr(self, "_codex_transport", None) + if t is None: + from agent.transports import get_transport + t = get_transport("codex_responses") + self._codex_transport = t + return t + + def _get_bedrock_transport(self): + """Return the cached BedrockTransport instance (lazy singleton).""" + t = getattr(self, "_bedrock_transport", None) + if t is None: + from agent.transports import get_transport + t = get_transport("bedrock_converse") + self._bedrock_transport = t + return t + def _prepare_anthropic_messages_for_api(self, api_messages: list) -> list: if not any( isinstance(msg, dict) and self._content_has_image_parts(msg.get("content")) @@ -6693,21 +6711,17 @@ def _build_api_kwargs(self, api_messages: list) -> dict: # AWS Bedrock native Converse API — bypasses the OpenAI client entirely. # The adapter handles message/tool conversion and boto3 calls directly. if self.api_mode == "bedrock_converse": - from agent.bedrock_adapter import build_converse_kwargs + _bt = self._get_bedrock_transport() region = getattr(self, "_bedrock_region", None) or "us-east-1" guardrail = getattr(self, "_bedrock_guardrail_config", None) - return { - "__bedrock_converse__": True, - "__bedrock_region__": region, - **build_converse_kwargs( - model=self.model, - messages=api_messages, - tools=self.tools, - max_tokens=self.max_tokens or 4096, - temperature=None, # Let the model use its default - guardrail_config=guardrail, - ), - } + return _bt.build_kwargs( + model=self.model, + messages=api_messages, + tools=self.tools, + max_tokens=self.max_tokens or 4096, + region=region, + guardrail_config=guardrail, + ) if self.api_mode == "codex_responses": instructions = "" @@ -9372,6 +9386,14 @@ def _stop_spinner(): error_details.append("response is None") else: error_details.append("response.content invalid (not a non-empty list)") + elif self.api_mode == "bedrock_converse": + _btv = self._get_bedrock_transport() + if not _btv.validate_response(response): + response_invalid = True + if response is None: + error_details.append("response is None") + else: + error_details.append("Bedrock response invalid (no output or choices)") else: if response is None or not hasattr(response, 'choices') or response.choices is None or not response.choices: response_invalid = True @@ -9534,6 +9556,10 @@ def _stop_spinner(): elif self.api_mode == "anthropic_messages": _tfr = self._get_anthropic_transport() finish_reason = _tfr.map_finish_reason(response.stop_reason) + elif self.api_mode == "bedrock_converse": + # Bedrock response is already normalized at dispatch — finish_reason + # is already in OpenAI format via normalize_converse_response() + finish_reason = response.choices[0].finish_reason if hasattr(response, "choices") and response.choices else "stop" else: finish_reason = response.choices[0].finish_reason assistant_message = response.choices[0].message @@ -10811,6 +10837,24 @@ def _stop_spinner(): ), ) finish_reason = _nr.finish_reason + elif self.api_mode == "bedrock_converse": + _bt = self._get_bedrock_transport() + _bnr = _bt.normalize_response(response) + assistant_message = SimpleNamespace( + content=_bnr.content, + tool_calls=[ + SimpleNamespace( + id=tc.id, + type="function", + function=SimpleNamespace(name=tc.name, arguments=tc.arguments), + ) + for tc in (_bnr.tool_calls or []) + ] or None, + reasoning=_bnr.reasoning, + reasoning_content=None, + reasoning_details=None, + ) + finish_reason = _bnr.finish_reason else: assistant_message = response.choices[0].message diff --git a/tests/agent/transports/test_bedrock_transport.py b/tests/agent/transports/test_bedrock_transport.py new file mode 100644 index 00000000000..f9d78a31ce1 --- /dev/null +++ b/tests/agent/transports/test_bedrock_transport.py @@ -0,0 +1,164 @@ +"""Tests for the BedrockTransport.""" + +import json +import pytest +from types import SimpleNamespace + +from agent.transports import get_transport +from agent.transports.types import NormalizedResponse, ToolCall + + +@pytest.fixture +def transport(): + import agent.transports.bedrock # noqa: F401 + return get_transport("bedrock_converse") + + +class TestBedrockBasic: + + def test_api_mode(self, transport): + assert transport.api_mode == "bedrock_converse" + + def test_registered(self, transport): + assert transport is not None + + +class TestBedrockBuildKwargs: + + def test_basic_kwargs(self, transport): + msgs = [{"role": "user", "content": "Hello"}] + kw = transport.build_kwargs(model="anthropic.claude-3-5-sonnet-20241022-v2:0", messages=msgs) + assert kw["modelId"] == "anthropic.claude-3-5-sonnet-20241022-v2:0" + assert kw["__bedrock_converse__"] is True + assert kw["__bedrock_region__"] == "us-east-1" + assert "messages" in kw + + def test_custom_region(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=msgs, + region="eu-west-1", + ) + assert kw["__bedrock_region__"] == "eu-west-1" + + def test_max_tokens(self, transport): + msgs = [{"role": "user", "content": "Hi"}] + kw = transport.build_kwargs( + model="anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=msgs, + max_tokens=8192, + ) + assert kw["inferenceConfig"]["maxTokens"] == 8192 + + +class TestBedrockConvertTools: + + def test_convert_tools(self, transport): + tools = [{ + "type": "function", + "function": { + "name": "terminal", + "description": "Run commands", + "parameters": {"type": "object", "properties": {"command": {"type": "string"}}}, + } + }] + result = transport.convert_tools(tools) + assert len(result) == 1 + assert result[0]["toolSpec"]["name"] == "terminal" + + +class TestBedrockValidate: + + def test_none(self, transport): + assert transport.validate_response(None) is False + + def test_raw_dict_valid(self, transport): + assert transport.validate_response({"output": {"message": {}}}) is True + + def test_raw_dict_invalid(self, transport): + assert transport.validate_response({"error": "fail"}) is False + + def test_normalized_valid(self, transport): + r = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="hi"))]) + assert transport.validate_response(r) is True + + +class TestBedrockMapFinishReason: + + def test_end_turn(self, transport): + assert transport.map_finish_reason("end_turn") == "stop" + + def test_tool_use(self, transport): + assert transport.map_finish_reason("tool_use") == "tool_calls" + + def test_max_tokens(self, transport): + assert transport.map_finish_reason("max_tokens") == "length" + + def test_guardrail(self, transport): + assert transport.map_finish_reason("guardrail_intervened") == "content_filter" + + def test_unknown(self, transport): + assert transport.map_finish_reason("unknown") == "stop" + + +class TestBedrockNormalize: + + def _make_bedrock_response(self, text="Hello", tool_calls=None, stop_reason="end_turn"): + """Build a raw Bedrock converse response dict.""" + content = [] + if text: + content.append({"text": text}) + if tool_calls: + for tc in tool_calls: + content.append({ + "toolUse": { + "toolUseId": tc["id"], + "name": tc["name"], + "input": tc["input"], + } + }) + return { + "output": {"message": {"role": "assistant", "content": content}}, + "stopReason": stop_reason, + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + } + + def test_text_response(self, transport): + raw = self._make_bedrock_response(text="Hello world") + nr = transport.normalize_response(raw) + assert isinstance(nr, NormalizedResponse) + assert nr.content == "Hello world" + assert nr.finish_reason == "stop" + + def test_tool_call_response(self, transport): + raw = self._make_bedrock_response( + text=None, + tool_calls=[{"id": "tool_1", "name": "terminal", "input": {"command": "ls"}}], + stop_reason="tool_use", + ) + nr = transport.normalize_response(raw) + assert nr.finish_reason == "tool_calls" + assert len(nr.tool_calls) == 1 + assert nr.tool_calls[0].name == "terminal" + + def test_already_normalized_response(self, transport): + """Test normalize_response handles already-normalized SimpleNamespace (from dispatch site).""" + pre_normalized = SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace( + content="Hello from Bedrock", + tool_calls=None, + reasoning=None, + reasoning_content=None, + ), + finish_reason="stop", + )], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + nr = transport.normalize_response(pre_normalized) + assert isinstance(nr, NormalizedResponse) + assert nr.content == "Hello from Bedrock" + assert nr.finish_reason == "stop" + assert nr.usage is not None + assert nr.usage.prompt_tokens == 10