feat: Add methods in MCPToolset for users to access MCP resources (read_resource, list_resources, get_resource_info)

This allows users to access MCP resources on their own within agent logic / using custom tools. I plan on also later adding it to the agent state.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 861910520
This commit is contained in:
Kathy Wu
2026-01-27 15:26:32 -08:00
committed by Copybara-Service
parent 3480b3b82d
commit 8f7d9659cf
2 changed files with 234 additions and 18 deletions
+98 -18
View File
@@ -15,17 +15,23 @@
from __future__ import annotations
import asyncio
import base64
import json
import logging
import sys
from typing import Any
from typing import Awaitable
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import TextIO
from typing import TypeVar
from typing import Union
import warnings
from mcp import StdioServerParameters
from mcp.types import ListResourcesResult
from mcp.types import ListToolsResult
from pydantic import model_validator
from typing_extensions import override
@@ -48,6 +54,9 @@ from .mcp_tool import MCPTool
logger = logging.getLogger("google_adk." + __name__)
T = TypeVar("T")
class McpToolset(BaseToolset):
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
@@ -140,6 +149,31 @@ class McpToolset(BaseToolset):
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
async def _execute_with_session(
self,
coroutine_func: Callable[[Any], Awaitable[T]],
error_message: str,
readonly_context: Optional[ReadonlyContext] = None,
) -> T:
"""Creates a session and executes a coroutine with it."""
headers = (
self._header_provider(readonly_context)
if self._header_provider and readonly_context
else None
)
session = await self._mcp_session_manager.create_session(headers=headers)
timeout_in_seconds = (
self._connection_params.timeout
if hasattr(self._connection_params, "timeout")
else None
)
try:
return await asyncio.wait_for(
coroutine_func(session), timeout=timeout_in_seconds
)
except Exception as e:
raise ConnectionError(f"{error_message}: {e}") from e
@retry_on_errors
async def get_tools(
self,
@@ -154,26 +188,12 @@ class McpToolset(BaseToolset):
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
headers = (
self._header_provider(readonly_context)
if self._header_provider and readonly_context
else None
)
# Get session from session manager
session = await self._mcp_session_manager.create_session(headers=headers)
# Fetch available tools from the MCP server
timeout_in_seconds = (
self._connection_params.timeout
if hasattr(self._connection_params, "timeout")
else None
tools_response: ListToolsResult = await self._execute_with_session(
lambda session: session.list_tools(),
"Failed to get tools from MCP server",
readonly_context,
)
try:
tools_response: ListToolsResult = await asyncio.wait_for(
session.list_tools(), timeout=timeout_in_seconds
)
except Exception as e:
raise ConnectionError(f"Failed to get tools from MCP server: {e}") from e
# Apply filtering based on context and tool_filter
tools = []
@@ -191,6 +211,66 @@ class McpToolset(BaseToolset):
tools.append(mcp_tool)
return tools
async def read_resource(
self, name: str, readonly_context: Optional[ReadonlyContext] = None
) -> Any:
"""Fetches and returns the content of the named resource.
This method will handle content decoding based on the MIME type reported by
the MCP server (e.g., JSON, text, base64 for binary).
Args:
name: The name of the resource to fetch.
readonly_context: Context used to provide headers for the MCP session.
Returns:
The content of the resource, decoded based on MIME type and encoding.
"""
result: Any = await self._execute_with_session(
lambda session: session.get_resource(name=name),
f"Failed to get resource {name} from MCP server",
readonly_context,
)
content = result.content
if result.encoding == "base64":
decoded_bytes = base64.b64decode(content)
if result.resource.mime_type == "application/json":
return json.loads(decoded_bytes.decode("utf-8"))
if result.resource.mime_type.startswith("text/"):
return decoded_bytes.decode("utf-8")
return decoded_bytes # Return as bytes for other binary types
if result.resource.mime_type == "application/json":
return json.loads(content)
return content
async def list_resources(
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[str]:
"""Returns a list of resource names available on the MCP server."""
result: ListResourcesResult = await self._execute_with_session(
lambda session: session.list_resources(),
"Failed to list resources from MCP server",
readonly_context,
)
return [resource.name for resource in result.resources]
async def get_resource_info(
self, name: str, readonly_context: Optional[ReadonlyContext] = None
) -> dict[str, Any]:
"""Returns metadata about a specific resource (name, MIME type, etc.)."""
result: ListResourcesResult = await self._execute_with_session(
lambda session: session.list_resources(),
"Failed to list resources from MCP server",
readonly_context,
)
for resource in result.resources:
if resource.name == name:
return resource.model_dump(mode="json", exclude_none=True)
raise ValueError(f"Resource with name '{name}' not found.")
async def close(self) -> None:
"""Performs cleanup and releases resources held by the toolset.
@@ -13,7 +13,9 @@
# limitations under the License.
import asyncio
import base64
from io import StringIO
import json
import sys
import unittest
from unittest.mock import AsyncMock
@@ -31,6 +33,8 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from mcp import StdioServerParameters
from mcp.types import ListResourcesResult
from mcp.types import Resource
import pytest
@@ -353,3 +357,135 @@ class TestMCPToolset:
# Assert that the original tools are not modified
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"
@pytest.mark.asyncio
async def test_list_resources(self):
"""Test listing resources."""
resources = [
Resource(
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
),
Resource(
name="data.json",
mime_type="application/json",
uri="file:///data.json",
),
]
list_resources_result = ListResourcesResult(resources=resources)
self.mock_session.list_resources = AsyncMock(
return_value=list_resources_result
)
toolset = MCPToolset(connection_params=self.mock_stdio_params)
toolset._mcp_session_manager = self.mock_session_manager
result = await toolset.list_resources()
assert result == ["file1.txt", "data.json"]
self.mock_session.list_resources.assert_called_once()
@pytest.mark.asyncio
async def test_get_resource_info_success(self):
"""Test getting resource info for an existing resource."""
resources = [
Resource(
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
),
Resource(
name="data.json",
mime_type="application/json",
uri="file:///data.json",
),
]
list_resources_result = ListResourcesResult(resources=resources)
self.mock_session.list_resources = AsyncMock(
return_value=list_resources_result
)
toolset = MCPToolset(connection_params=self.mock_stdio_params)
toolset._mcp_session_manager = self.mock_session_manager
result = await toolset.get_resource_info("data.json")
assert result == {
"name": "data.json",
"mime_type": "application/json",
"uri": "file:///data.json",
}
self.mock_session.list_resources.assert_called_once()
@pytest.mark.asyncio
async def test_get_resource_info_not_found(self):
"""Test getting resource info for a non-existent resource."""
resources = [
Resource(
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
),
]
list_resources_result = ListResourcesResult(resources=resources)
self.mock_session.list_resources = AsyncMock(
return_value=list_resources_result
)
toolset = MCPToolset(connection_params=self.mock_stdio_params)
toolset._mcp_session_manager = self.mock_session_manager
with pytest.raises(
ValueError, match="Resource with name 'other.json' not found."
):
await toolset.get_resource_info("other.json")
@pytest.mark.parametrize(
"name,mime_type,content,encoding,expected_result",
[
("file1.txt", "text/plain", "hello world", None, "hello world"),
(
"data.json",
"application/json",
'{"key": "value"}',
None,
{"key": "value"},
),
(
"file1_b64.txt",
"text/plain",
base64.b64encode(b"hello world").decode("ascii"),
"base64",
"hello world",
),
(
"data_b64.json",
"application/json",
base64.b64encode(b'{"key": "value"}').decode("ascii"),
"base64",
{"key": "value"},
),
(
"data.bin",
"application/octet-stream",
base64.b64encode(b"\x01\x02\x03").decode("ascii"),
"base64",
b"\x01\x02\x03",
),
],
)
@pytest.mark.asyncio
async def test_read_resource(
self, name, mime_type, content, encoding, expected_result
):
"""Test reading various resource types."""
get_resource_result = MagicMock()
get_resource_result.resource = Resource(
name=name, mime_type=mime_type, uri=f"file:///{name}"
)
get_resource_result.content = content
get_resource_result.encoding = encoding
self.mock_session.get_resource = AsyncMock(return_value=get_resource_result)
toolset = MCPToolset(connection_params=self.mock_stdio_params)
toolset._mcp_session_manager = self.mock_session_manager
result = await toolset.read_resource(name)
assert result == expected_result
self.mock_session.get_resource.assert_called_once_with(name=name)