From 4df926388b6e9ebcf517fbacf2f5532fd73b0f71 Mon Sep 17 00:00:00 2001 From: "Wei Sun (Jack)" Date: Wed, 22 Oct 2025 00:57:31 -0700 Subject: [PATCH] fix: Returns dict as result from McpTool The `BaseTool` expects the run_async to return a json-serializable object. By model_dump the McpTool result explicitly can allow what ADK runtime sees is identical to what is persisted in the session event list. Before the change, runtime sees CallToolResult instance and Session persists its serialized dict. https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/types.py#L916-L922 PiperOrigin-RevId: 822465432 --- src/google/adk/tools/mcp_tool/mcp_tool.py | 8 +-- .../unittests/tools/mcp_tool/test_mcp_tool.py | 59 ++++++++++++------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 545f81ee..09737f09 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -80,7 +80,7 @@ class McpTool(BaseAuthenticatedTool): Callable[[ReadonlyContext], Dict[str, str]] ] = None, ): - """Initializes an MCPTool. + """Initializes an McpTool. This tool wraps an MCP Tool interface and uses a session manager to communicate with the MCP server. @@ -186,7 +186,7 @@ class McpTool(BaseAuthenticatedTool): @override async def _run_async_impl( self, *, args, tool_context: ToolContext, credential: AuthCredential - ): + ) -> Dict[str, Any]: """Runs the tool asynchronously. Args: @@ -217,7 +217,7 @@ class McpTool(BaseAuthenticatedTool): ) response = await session.call_tool(self._mcp_tool.name, arguments=args) - return response + return response.model_dump(exclude_none=True, mode="json") async def _get_headers( self, tool_context: ToolContext, credential: AuthCredential @@ -282,7 +282,7 @@ class McpTool(BaseAuthenticatedTool): != APIKeyIn.header ): error_msg = ( - "MCPTool only supports header-based API key authentication." + "McpTool only supports header-based API key authentication." " Configured location:" f" {self._credentials_manager._auth_config.auth_scheme.in_}" ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 37c2df89..3408bfa3 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -36,6 +36,8 @@ try: from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.tool_context import ToolContext from google.genai.types import FunctionDeclaration + from mcp.types import CallToolResult + from mcp.types import TextContent except ImportError as e: if sys.version_info < (3, 10): # Create dummy classes to prevent NameError during test collection @@ -47,6 +49,8 @@ except ImportError as e: MCPTool = DummyClass ToolContext = DummyClass FunctionDeclaration = DummyClass + CallToolResult = DummyClass + TextContent = DummyClass else: raise e @@ -150,9 +154,11 @@ class TestMCPTool: mcp_session_manager=self.mock_session_manager, ) - # Mock the session response - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -161,7 +167,8 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=None ) - assert result == expected_response + # Verify the result matches the model_dump output + assert result == mcp_response.model_dump(exclude_none=True, mode="json") self.mock_session_manager.create_session.assert_called_once_with( headers=None ) @@ -184,9 +191,11 @@ class TestMCPTool: auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth ) - # Mock the session response - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -195,7 +204,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") # Check that headers were passed correctly self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args @@ -322,7 +331,7 @@ class TestMCPTool: with pytest.raises( ValueError, - match="MCPTool only supports header-based API key authentication", + match="McpTool only supports header-based API key authentication", ): await tool._get_headers(tool_context, auth_credential) @@ -354,7 +363,7 @@ class TestMCPTool: with pytest.raises( ValueError, - match="MCPTool only supports header-based API key authentication", + match="McpTool only supports header-based API key authentication", ): await tool._get_headers(tool_context, auth_credential) @@ -460,9 +469,11 @@ class TestMCPTool: auth_credential=auth_credential, ) - # Mock the session response - expected_response = {"result": "authenticated_success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="authenticated_success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -471,7 +482,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=auth_credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") # Check that headers were passed correctly with custom API key header self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args @@ -545,7 +556,7 @@ class TestMCPTool: mock_logger.error.assert_called_once() logged_message = mock_logger.error.call_args[0][0] assert ( - "MCPTool only supports header-based API key authentication" + "McpTool only supports header-based API key authentication" in logged_message ) @@ -652,8 +663,11 @@ class TestMCPTool: header_provider=header_provider, ) - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) tool_context._invocation_context = Mock() @@ -663,7 +677,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=None ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") header_provider.assert_called_once() self.mock_session_manager.create_session.assert_called_once_with( headers=expected_headers @@ -688,8 +702,11 @@ class TestMCPTool: auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth ) - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) tool_context._invocation_context = Mock() @@ -699,7 +716,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") header_provider.assert_called_once() self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args