From 0cc3d6d6d554ded04cde0ca8b75ee433c49cdcad Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Tue, 18 Nov 2025 19:09:48 +0000 Subject: [PATCH] Feat/expose mcps streamable http custom httpx factory parameter (#2997) * feat: Add support for custom HTTPX client factory in StreamableHTTPConnectionParams * Update src/google/adk/tools/mcp_tool/mcp_session_manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * unit tested mock * provide default - httpx client factory can't be none * feat: Enhance StreamableHTTPConnectionParams with httpx_client_factory attribute * fmt * fmt * refactor: Rename test_init_with_streamable_http_none_httpx_factory to test_init_with_streamable_http_default_httpx_factory for clarity * isort * fmt --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Kathy Wu <108756731+wukath@users.noreply.github.com> --- .../adk/tools/mcp_tool/mcp_session_manager.py | 16 ++++++ .../mcp_tool/test_mcp_session_manager.py | 53 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index d95d48f2..2b5cd967 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -25,17 +25,22 @@ import sys from typing import Any from typing import Dict from typing import Optional +from typing import Protocol +from typing import runtime_checkable from typing import TextIO from typing import Union import anyio from pydantic import BaseModel +from pydantic import ConfigDict try: from mcp import ClientSession from mcp import StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import create_mcp_http_client + from mcp.client.streamable_http import McpHttpClientFactory from mcp.client.streamable_http import streamablehttp_client except ImportError as e: @@ -84,6 +89,11 @@ class SseConnectionParams(BaseModel): sse_read_timeout: float = 60 * 5.0 +@runtime_checkable +class CheckableMcpHttpClientFactory(McpHttpClientFactory, Protocol): + pass + + class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP Streamable HTTP connection. @@ -99,13 +109,18 @@ class StreamableHTTPConnectionParams(BaseModel): Streamable HTTP server. terminate_on_close: Whether to terminate the MCP Streamable HTTP server when the connection is closed. + httpx_client_factory: Factory function to create a custom HTTPX client. If + not provided, a default factory will be used. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 terminate_on_close: bool = True + httpx_client_factory: CheckableMcpHttpClientFactory = create_mcp_http_client def retry_on_closed_resource(func): @@ -286,6 +301,7 @@ class MCPSessionManager: seconds=self._connection_params.sse_read_timeout ), terminate_on_close=self._connection_params.terminate_on_close, + httpx_client_factory=self._connection_params.httpx_client_factory, ) else: raise ValueError( diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 6c001ccf..0c0a05e5 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -146,6 +146,59 @@ class TestMCPSessionManager: assert manager._connection_params == http_params + @patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client") + def test_init_with_streamable_http_custom_httpx_factory( + self, mock_streamablehttp_client + ): + """Test that streamablehttp_client is called with custom httpx_client_factory.""" + from datetime import timedelta + + custom_httpx_factory = Mock() + + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", + timeout=15.0, + httpx_client_factory=custom_httpx_factory, + ) + manager = MCPSessionManager(http_params) + + manager._create_client() + + mock_streamablehttp_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=timedelta(seconds=15.0), + sse_read_timeout=timedelta(seconds=300.0), + terminate_on_close=True, + httpx_client_factory=custom_httpx_factory, + ) + + @pytest.mark.asyncio + @patch("google.adk.tools.mcp_tool.mcp_session_manager.streamablehttp_client") + async def test_init_with_streamable_http_default_httpx_factory( + self, mock_streamablehttp_client + ): + """Test that streamablehttp_client is called with custom httpx_client_factory.""" + from datetime import timedelta + + from mcp.client.streamable_http import create_mcp_http_client + + http_params = StreamableHTTPConnectionParams( + url="https://example.com/mcp", timeout=15.0 + ) + manager = MCPSessionManager(http_params) + + manager._create_client() + + mock_streamablehttp_client.assert_called_once_with( + url="https://example.com/mcp", + headers=None, + timeout=timedelta(seconds=15.0), + sse_read_timeout=timedelta(seconds=300.0), + terminate_on_close=True, + httpx_client_factory=create_mcp_http_client, + ) + def test_generate_session_key_stdio(self): """Test session key generation for stdio connections.""" manager = MCPSessionManager(self.mock_stdio_connection_params)