diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 545f81ee..09737f09 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -80,7 +80,7 @@ class McpTool(BaseAuthenticatedTool): Callable[[ReadonlyContext], Dict[str, str]] ] = None, ): - """Initializes an MCPTool. + """Initializes an McpTool. This tool wraps an MCP Tool interface and uses a session manager to communicate with the MCP server. @@ -186,7 +186,7 @@ class McpTool(BaseAuthenticatedTool): @override async def _run_async_impl( self, *, args, tool_context: ToolContext, credential: AuthCredential - ): + ) -> Dict[str, Any]: """Runs the tool asynchronously. Args: @@ -217,7 +217,7 @@ class McpTool(BaseAuthenticatedTool): ) response = await session.call_tool(self._mcp_tool.name, arguments=args) - return response + return response.model_dump(exclude_none=True, mode="json") async def _get_headers( self, tool_context: ToolContext, credential: AuthCredential @@ -282,7 +282,7 @@ class McpTool(BaseAuthenticatedTool): != APIKeyIn.header ): error_msg = ( - "MCPTool only supports header-based API key authentication." + "McpTool only supports header-based API key authentication." " Configured location:" f" {self._credentials_manager._auth_config.auth_scheme.in_}" ) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 37c2df89..3408bfa3 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -36,6 +36,8 @@ try: from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.tool_context import ToolContext from google.genai.types import FunctionDeclaration + from mcp.types import CallToolResult + from mcp.types import TextContent except ImportError as e: if sys.version_info < (3, 10): # Create dummy classes to prevent NameError during test collection @@ -47,6 +49,8 @@ except ImportError as e: MCPTool = DummyClass ToolContext = DummyClass FunctionDeclaration = DummyClass + CallToolResult = DummyClass + TextContent = DummyClass else: raise e @@ -150,9 +154,11 @@ class TestMCPTool: mcp_session_manager=self.mock_session_manager, ) - # Mock the session response - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -161,7 +167,8 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=None ) - assert result == expected_response + # Verify the result matches the model_dump output + assert result == mcp_response.model_dump(exclude_none=True, mode="json") self.mock_session_manager.create_session.assert_called_once_with( headers=None ) @@ -184,9 +191,11 @@ class TestMCPTool: auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth ) - # Mock the session response - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -195,7 +204,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") # Check that headers were passed correctly self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args @@ -322,7 +331,7 @@ class TestMCPTool: with pytest.raises( ValueError, - match="MCPTool only supports header-based API key authentication", + match="McpTool only supports header-based API key authentication", ): await tool._get_headers(tool_context, auth_credential) @@ -354,7 +363,7 @@ class TestMCPTool: with pytest.raises( ValueError, - match="MCPTool only supports header-based API key authentication", + match="McpTool only supports header-based API key authentication", ): await tool._get_headers(tool_context, auth_credential) @@ -460,9 +469,11 @@ class TestMCPTool: auth_credential=auth_credential, ) - # Mock the session response - expected_response = {"result": "authenticated_success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="authenticated_success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) args = {"param1": "test_value"} @@ -471,7 +482,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=auth_credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") # Check that headers were passed correctly with custom API key header self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args @@ -545,7 +556,7 @@ class TestMCPTool: mock_logger.error.assert_called_once() logged_message = mock_logger.error.call_args[0][0] assert ( - "MCPTool only supports header-based API key authentication" + "McpTool only supports header-based API key authentication" in logged_message ) @@ -652,8 +663,11 @@ class TestMCPTool: header_provider=header_provider, ) - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) tool_context._invocation_context = Mock() @@ -663,7 +677,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=None ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") header_provider.assert_called_once() self.mock_session_manager.create_session.assert_called_once_with( headers=expected_headers @@ -688,8 +702,11 @@ class TestMCPTool: auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth ) - expected_response = {"result": "success"} - self.mock_session.call_tool = AsyncMock(return_value=expected_response) + # Mock the session response - must return CallToolResult + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) tool_context = Mock(spec=ToolContext) tool_context._invocation_context = Mock() @@ -699,7 +716,7 @@ class TestMCPTool: args=args, tool_context=tool_context, credential=credential ) - assert result == expected_response + assert result == mcp_response.model_dump(exclude_none=True, mode="json") header_provider.assert_called_once() self.mock_session_manager.create_session.assert_called_once() call_args = self.mock_session_manager.create_session.call_args