fix: Returns dict as result from McpTool

The `BaseTool` expects the run_async to return a json-serializable object. By model_dump the McpTool result explicitly can allow what ADK runtime sees is identical to what is persisted in the session event list.

Before the change, runtime sees CallToolResult instance and Session persists its serialized dict.

https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/types.py#L916-L922

PiperOrigin-RevId: 822465432
This commit is contained in:
Wei Sun (Jack)
2025-10-22 00:57:31 -07:00
committed by Copybara-Service
parent d4dc645478
commit 4df926388b
2 changed files with 42 additions and 25 deletions
+4 -4
View File
@@ -80,7 +80,7 @@ class McpTool(BaseAuthenticatedTool):
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
):
"""Initializes an MCPTool.
"""Initializes an McpTool.
This tool wraps an MCP Tool interface and uses a session manager to
communicate with the MCP server.
@@ -186,7 +186,7 @@ class McpTool(BaseAuthenticatedTool):
@override
async def _run_async_impl(
self, *, args, tool_context: ToolContext, credential: AuthCredential
):
) -> Dict[str, Any]:
"""Runs the tool asynchronously.
Args:
@@ -217,7 +217,7 @@ class McpTool(BaseAuthenticatedTool):
)
response = await session.call_tool(self._mcp_tool.name, arguments=args)
return response
return response.model_dump(exclude_none=True, mode="json")
async def _get_headers(
self, tool_context: ToolContext, credential: AuthCredential
@@ -282,7 +282,7 @@ class McpTool(BaseAuthenticatedTool):
!= APIKeyIn.header
):
error_msg = (
"MCPTool only supports header-based API key authentication."
"McpTool only supports header-based API key authentication."
" Configured location:"
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
)
+38 -21
View File
@@ -36,6 +36,8 @@ try:
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionDeclaration
from mcp.types import CallToolResult
from mcp.types import TextContent
except ImportError as e:
if sys.version_info < (3, 10):
# Create dummy classes to prevent NameError during test collection
@@ -47,6 +49,8 @@ except ImportError as e:
MCPTool = DummyClass
ToolContext = DummyClass
FunctionDeclaration = DummyClass
CallToolResult = DummyClass
TextContent = DummyClass
else:
raise e
@@ -150,9 +154,11 @@ class TestMCPTool:
mcp_session_manager=self.mock_session_manager,
)
# Mock the session response
expected_response = {"result": "success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
# Mock the session response - must return CallToolResult
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
@@ -161,7 +167,8 @@ class TestMCPTool:
args=args, tool_context=tool_context, credential=None
)
assert result == expected_response
# Verify the result matches the model_dump output
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
self.mock_session_manager.create_session.assert_called_once_with(
headers=None
)
@@ -184,9 +191,11 @@ class TestMCPTool:
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
)
# Mock the session response
expected_response = {"result": "success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
# Mock the session response - must return CallToolResult
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
@@ -195,7 +204,7 @@ class TestMCPTool:
args=args, tool_context=tool_context, credential=credential
)
assert result == expected_response
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
# Check that headers were passed correctly
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
@@ -322,7 +331,7 @@ class TestMCPTool:
with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@@ -354,7 +363,7 @@ class TestMCPTool:
with pytest.raises(
ValueError,
match="MCPTool only supports header-based API key authentication",
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@@ -460,9 +469,11 @@ class TestMCPTool:
auth_credential=auth_credential,
)
# Mock the session response
expected_response = {"result": "authenticated_success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
# Mock the session response - must return CallToolResult
mcp_response = CallToolResult(
content=[TextContent(type="text", text="authenticated_success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
@@ -471,7 +482,7 @@ class TestMCPTool:
args=args, tool_context=tool_context, credential=auth_credential
)
assert result == expected_response
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
# Check that headers were passed correctly with custom API key header
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
@@ -545,7 +556,7 @@ class TestMCPTool:
mock_logger.error.assert_called_once()
logged_message = mock_logger.error.call_args[0][0]
assert (
"MCPTool only supports header-based API key authentication"
"McpTool only supports header-based API key authentication"
in logged_message
)
@@ -652,8 +663,11 @@ class TestMCPTool:
header_provider=header_provider,
)
expected_response = {"result": "success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
# Mock the session response - must return CallToolResult
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
tool_context._invocation_context = Mock()
@@ -663,7 +677,7 @@ class TestMCPTool:
args=args, tool_context=tool_context, credential=None
)
assert result == expected_response
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
header_provider.assert_called_once()
self.mock_session_manager.create_session.assert_called_once_with(
headers=expected_headers
@@ -688,8 +702,11 @@ class TestMCPTool:
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
)
expected_response = {"result": "success"}
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
# Mock the session response - must return CallToolResult
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
tool_context._invocation_context = Mock()
@@ -699,7 +716,7 @@ class TestMCPTool:
args=args, tool_context=tool_context, credential=credential
)
assert result == expected_response
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
header_provider.assert_called_once()
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args