diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index b4d5682e..2e79b820 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -15,17 +15,23 @@ from __future__ import annotations import asyncio +import base64 +import json import logging import sys +from typing import Any +from typing import Awaitable from typing import Callable from typing import Dict from typing import List from typing import Optional from typing import TextIO +from typing import TypeVar from typing import Union import warnings from mcp import StdioServerParameters +from mcp.types import ListResourcesResult from mcp.types import ListToolsResult from pydantic import model_validator from typing_extensions import override @@ -48,6 +54,9 @@ from .mcp_tool import MCPTool logger = logging.getLogger("google_adk." + __name__) +T = TypeVar("T") + + class McpToolset(BaseToolset): """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. @@ -140,6 +149,31 @@ class McpToolset(BaseToolset): self._auth_credential = auth_credential self._require_confirmation = require_confirmation + async def _execute_with_session( + self, + coroutine_func: Callable[[Any], Awaitable[T]], + error_message: str, + readonly_context: Optional[ReadonlyContext] = None, + ) -> T: + """Creates a session and executes a coroutine with it.""" + headers = ( + self._header_provider(readonly_context) + if self._header_provider and readonly_context + else None + ) + session = await self._mcp_session_manager.create_session(headers=headers) + timeout_in_seconds = ( + self._connection_params.timeout + if hasattr(self._connection_params, "timeout") + else None + ) + try: + return await asyncio.wait_for( + coroutine_func(session), timeout=timeout_in_seconds + ) + except Exception as e: + raise ConnectionError(f"{error_message}: {e}") from e + @retry_on_errors async def get_tools( self, @@ -154,26 +188,12 @@ class McpToolset(BaseToolset): Returns: List[BaseTool]: A list of tools available under the specified context. """ - headers = ( - self._header_provider(readonly_context) - if self._header_provider and readonly_context - else None - ) - # Get session from session manager - session = await self._mcp_session_manager.create_session(headers=headers) - # Fetch available tools from the MCP server - timeout_in_seconds = ( - self._connection_params.timeout - if hasattr(self._connection_params, "timeout") - else None + tools_response: ListToolsResult = await self._execute_with_session( + lambda session: session.list_tools(), + "Failed to get tools from MCP server", + readonly_context, ) - try: - tools_response: ListToolsResult = await asyncio.wait_for( - session.list_tools(), timeout=timeout_in_seconds - ) - except Exception as e: - raise ConnectionError(f"Failed to get tools from MCP server: {e}") from e # Apply filtering based on context and tool_filter tools = [] @@ -191,6 +211,66 @@ class McpToolset(BaseToolset): tools.append(mcp_tool) return tools + async def read_resource( + self, name: str, readonly_context: Optional[ReadonlyContext] = None + ) -> Any: + """Fetches and returns the content of the named resource. + + This method will handle content decoding based on the MIME type reported by + the MCP server (e.g., JSON, text, base64 for binary). + + Args: + name: The name of the resource to fetch. + readonly_context: Context used to provide headers for the MCP session. + + Returns: + The content of the resource, decoded based on MIME type and encoding. + """ + result: Any = await self._execute_with_session( + lambda session: session.get_resource(name=name), + f"Failed to get resource {name} from MCP server", + readonly_context, + ) + + content = result.content + if result.encoding == "base64": + decoded_bytes = base64.b64decode(content) + if result.resource.mime_type == "application/json": + return json.loads(decoded_bytes.decode("utf-8")) + if result.resource.mime_type.startswith("text/"): + return decoded_bytes.decode("utf-8") + return decoded_bytes # Return as bytes for other binary types + + if result.resource.mime_type == "application/json": + return json.loads(content) + + return content + + async def list_resources( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> list[str]: + """Returns a list of resource names available on the MCP server.""" + result: ListResourcesResult = await self._execute_with_session( + lambda session: session.list_resources(), + "Failed to list resources from MCP server", + readonly_context, + ) + return [resource.name for resource in result.resources] + + async def get_resource_info( + self, name: str, readonly_context: Optional[ReadonlyContext] = None + ) -> dict[str, Any]: + """Returns metadata about a specific resource (name, MIME type, etc.).""" + result: ListResourcesResult = await self._execute_with_session( + lambda session: session.list_resources(), + "Failed to list resources from MCP server", + readonly_context, + ) + for resource in result.resources: + if resource.name == name: + return resource.model_dump(mode="json", exclude_none=True) + raise ValueError(f"Resource with name '{name}' not found.") + async def close(self) -> None: """Performs cleanup and releases resources held by the toolset. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 83b112d8..d2fece0e 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -13,7 +13,9 @@ # limitations under the License. import asyncio +import base64 from io import StringIO +import json import sys import unittest from unittest.mock import AsyncMock @@ -31,6 +33,8 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset from google.adk.tools.mcp_tool.mcp_toolset import McpToolset from mcp import StdioServerParameters +from mcp.types import ListResourcesResult +from mcp.types import Resource import pytest @@ -353,3 +357,135 @@ class TestMCPToolset: # Assert that the original tools are not modified assert tools[0].name == "tool1" assert tools[1].name == "tool2" + + @pytest.mark.asyncio + async def test_list_resources(self): + """Test listing resources.""" + resources = [ + Resource( + name="file1.txt", mime_type="text/plain", uri="file:///file1.txt" + ), + Resource( + name="data.json", + mime_type="application/json", + uri="file:///data.json", + ), + ] + list_resources_result = ListResourcesResult(resources=resources) + self.mock_session.list_resources = AsyncMock( + return_value=list_resources_result + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + result = await toolset.list_resources() + + assert result == ["file1.txt", "data.json"] + self.mock_session.list_resources.assert_called_once() + + @pytest.mark.asyncio + async def test_get_resource_info_success(self): + """Test getting resource info for an existing resource.""" + resources = [ + Resource( + name="file1.txt", mime_type="text/plain", uri="file:///file1.txt" + ), + Resource( + name="data.json", + mime_type="application/json", + uri="file:///data.json", + ), + ] + list_resources_result = ListResourcesResult(resources=resources) + self.mock_session.list_resources = AsyncMock( + return_value=list_resources_result + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + result = await toolset.get_resource_info("data.json") + + assert result == { + "name": "data.json", + "mime_type": "application/json", + "uri": "file:///data.json", + } + self.mock_session.list_resources.assert_called_once() + + @pytest.mark.asyncio + async def test_get_resource_info_not_found(self): + """Test getting resource info for a non-existent resource.""" + resources = [ + Resource( + name="file1.txt", mime_type="text/plain", uri="file:///file1.txt" + ), + ] + list_resources_result = ListResourcesResult(resources=resources) + self.mock_session.list_resources = AsyncMock( + return_value=list_resources_result + ) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + with pytest.raises( + ValueError, match="Resource with name 'other.json' not found." + ): + await toolset.get_resource_info("other.json") + + @pytest.mark.parametrize( + "name,mime_type,content,encoding,expected_result", + [ + ("file1.txt", "text/plain", "hello world", None, "hello world"), + ( + "data.json", + "application/json", + '{"key": "value"}', + None, + {"key": "value"}, + ), + ( + "file1_b64.txt", + "text/plain", + base64.b64encode(b"hello world").decode("ascii"), + "base64", + "hello world", + ), + ( + "data_b64.json", + "application/json", + base64.b64encode(b'{"key": "value"}').decode("ascii"), + "base64", + {"key": "value"}, + ), + ( + "data.bin", + "application/octet-stream", + base64.b64encode(b"\x01\x02\x03").decode("ascii"), + "base64", + b"\x01\x02\x03", + ), + ], + ) + @pytest.mark.asyncio + async def test_read_resource( + self, name, mime_type, content, encoding, expected_result + ): + """Test reading various resource types.""" + get_resource_result = MagicMock() + get_resource_result.resource = Resource( + name=name, mime_type=mime_type, uri=f"file:///{name}" + ) + get_resource_result.content = content + get_resource_result.encoding = encoding + self.mock_session.get_resource = AsyncMock(return_value=get_resource_result) + + toolset = MCPToolset(connection_params=self.mock_stdio_params) + toolset._mcp_session_manager = self.mock_session_manager + + result = await toolset.read_resource(name) + + assert result == expected_result + self.mock_session.get_resource.assert_called_once_with(name=name)