You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Returns dict as result from McpTool
The `BaseTool` expects the run_async to return a json-serializable object. By model_dump the McpTool result explicitly can allow what ADK runtime sees is identical to what is persisted in the session event list. Before the change, runtime sees CallToolResult instance and Session persists its serialized dict. https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/types.py#L916-L922 PiperOrigin-RevId: 822465432
This commit is contained in:
committed by
Copybara-Service
parent
d4dc645478
commit
4df926388b
@@ -80,7 +80,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
Callable[[ReadonlyContext], Dict[str, str]]
|
||||
] = None,
|
||||
):
|
||||
"""Initializes an MCPTool.
|
||||
"""Initializes an McpTool.
|
||||
|
||||
This tool wraps an MCP Tool interface and uses a session manager to
|
||||
communicate with the MCP server.
|
||||
@@ -186,7 +186,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
@override
|
||||
async def _run_async_impl(
|
||||
self, *, args, tool_context: ToolContext, credential: AuthCredential
|
||||
):
|
||||
) -> Dict[str, Any]:
|
||||
"""Runs the tool asynchronously.
|
||||
|
||||
Args:
|
||||
@@ -217,7 +217,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
)
|
||||
|
||||
response = await session.call_tool(self._mcp_tool.name, arguments=args)
|
||||
return response
|
||||
return response.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
async def _get_headers(
|
||||
self, tool_context: ToolContext, credential: AuthCredential
|
||||
@@ -282,7 +282,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
!= APIKeyIn.header
|
||||
):
|
||||
error_msg = (
|
||||
"MCPTool only supports header-based API key authentication."
|
||||
"McpTool only supports header-based API key authentication."
|
||||
" Configured location:"
|
||||
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
|
||||
)
|
||||
|
||||
@@ -36,6 +36,8 @@ try:
|
||||
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from mcp.types import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
except ImportError as e:
|
||||
if sys.version_info < (3, 10):
|
||||
# Create dummy classes to prevent NameError during test collection
|
||||
@@ -47,6 +49,8 @@ except ImportError as e:
|
||||
MCPTool = DummyClass
|
||||
ToolContext = DummyClass
|
||||
FunctionDeclaration = DummyClass
|
||||
CallToolResult = DummyClass
|
||||
TextContent = DummyClass
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -150,9 +154,11 @@ class TestMCPTool:
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
# Mock the session response
|
||||
expected_response = {"result": "success"}
|
||||
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
|
||||
# Mock the session response - must return CallToolResult
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
@@ -161,7 +167,8 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context, credential=None
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
# Verify the result matches the model_dump output
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
self.mock_session_manager.create_session.assert_called_once_with(
|
||||
headers=None
|
||||
)
|
||||
@@ -184,9 +191,11 @@ class TestMCPTool:
|
||||
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
|
||||
)
|
||||
|
||||
# Mock the session response
|
||||
expected_response = {"result": "success"}
|
||||
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
|
||||
# Mock the session response - must return CallToolResult
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
@@ -195,7 +204,7 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context, credential=credential
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
# Check that headers were passed correctly
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
@@ -322,7 +331,7 @@ class TestMCPTool:
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCPTool only supports header-based API key authentication",
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@@ -354,7 +363,7 @@ class TestMCPTool:
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCPTool only supports header-based API key authentication",
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@@ -460,9 +469,11 @@ class TestMCPTool:
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
# Mock the session response
|
||||
expected_response = {"result": "authenticated_success"}
|
||||
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
|
||||
# Mock the session response - must return CallToolResult
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="authenticated_success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
@@ -471,7 +482,7 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context, credential=auth_credential
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
# Check that headers were passed correctly with custom API key header
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
@@ -545,7 +556,7 @@ class TestMCPTool:
|
||||
mock_logger.error.assert_called_once()
|
||||
logged_message = mock_logger.error.call_args[0][0]
|
||||
assert (
|
||||
"MCPTool only supports header-based API key authentication"
|
||||
"McpTool only supports header-based API key authentication"
|
||||
in logged_message
|
||||
)
|
||||
|
||||
@@ -652,8 +663,11 @@ class TestMCPTool:
|
||||
header_provider=header_provider,
|
||||
)
|
||||
|
||||
expected_response = {"result": "success"}
|
||||
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
|
||||
# Mock the session response - must return CallToolResult
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
tool_context._invocation_context = Mock()
|
||||
@@ -663,7 +677,7 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context, credential=None
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
header_provider.assert_called_once()
|
||||
self.mock_session_manager.create_session.assert_called_once_with(
|
||||
headers=expected_headers
|
||||
@@ -688,8 +702,11 @@ class TestMCPTool:
|
||||
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
|
||||
)
|
||||
|
||||
expected_response = {"result": "success"}
|
||||
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
|
||||
# Mock the session response - must return CallToolResult
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
tool_context._invocation_context = Mock()
|
||||
@@ -699,7 +716,7 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context, credential=credential
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
header_provider.assert_called_once()
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
|
||||
Reference in New Issue
Block a user