Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,17 @@ class SseConnectionParams(BaseModel):
server.
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
server.
httpx_client_factory: Factory function to create a custom HTTPX client. If
not provided, a default factory will be used.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

url: str
headers: dict[str, Any] | None = None
timeout: float = 5.0
sse_read_timeout: float = 60 * 5.0
httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client


@runtime_checkable
Expand Down Expand Up @@ -398,6 +403,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
headers=merged_headers,
timeout=self._connection_params.timeout,
sse_read_timeout=self._connection_params.sse_read_timeout,
httpx_client_factory=self._connection_params.httpx_client_factory,
)
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
client = streamablehttp_client(
Expand Down
43 changes: 43 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,49 @@ def test_init_with_sse_connection_params(self):

assert manager._connection_params == sse_params

@patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client")
def test_init_with_sse_custom_httpx_factory(self, mock_sse_client):
"""Test that sse_client is called with custom httpx_client_factory."""
custom_httpx_factory = Mock()

sse_params = SseConnectionParams(
url="https://example.com/mcp",
timeout=10.0,
httpx_client_factory=custom_httpx_factory,
)
manager = MCPSessionManager(sse_params)

manager._create_client()

mock_sse_client.assert_called_once_with(
url="https://example.com/mcp",
headers=None,
timeout=10.0,
sse_read_timeout=300.0,
httpx_client_factory=custom_httpx_factory,
)

@patch("google.adk.tools.mcp_tool.mcp_session_manager.sse_client")
def test_init_with_sse_default_httpx_factory(self, mock_sse_client):
"""Test that sse_client is called with default httpx_client_factory."""
sse_params = SseConnectionParams(
url="https://example.com/mcp",
timeout=10.0,
)
manager = MCPSessionManager(sse_params)

manager._create_client()

mock_sse_client.assert_called_once_with(
url="https://example.com/mcp",
headers=None,
timeout=10.0,
sse_read_timeout=300.0,
httpx_client_factory=SseConnectionParams.model_fields[
"httpx_client_factory"
].get_default(),
)

def test_init_with_streamable_http_params(self):
"""Test initialization with StreamableHTTPConnectionParams."""
http_params = StreamableHTTPConnectionParams(
Expand Down