You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add methods in MCPToolset for users to access MCP resources (read_resource, list_resources, get_resource_info)
This allows users to access MCP resources on their own within agent logic / using custom tools. I plan on also later adding it to the agent state. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 861910520
This commit is contained in:
committed by
Copybara-Service
parent
3480b3b82d
commit
8f7d9659cf
@@ -15,17 +15,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Awaitable
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TextIO
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import ListResourcesResult
|
||||
from mcp.types import ListToolsResult
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
@@ -48,6 +54,9 @@ from .mcp_tool import MCPTool
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class McpToolset(BaseToolset):
|
||||
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
||||
|
||||
@@ -140,6 +149,31 @@ class McpToolset(BaseToolset):
|
||||
self._auth_credential = auth_credential
|
||||
self._require_confirmation = require_confirmation
|
||||
|
||||
async def _execute_with_session(
|
||||
self,
|
||||
coroutine_func: Callable[[Any], Awaitable[T]],
|
||||
error_message: str,
|
||||
readonly_context: Optional[ReadonlyContext] = None,
|
||||
) -> T:
|
||||
"""Creates a session and executes a coroutine with it."""
|
||||
headers = (
|
||||
self._header_provider(readonly_context)
|
||||
if self._header_provider and readonly_context
|
||||
else None
|
||||
)
|
||||
session = await self._mcp_session_manager.create_session(headers=headers)
|
||||
timeout_in_seconds = (
|
||||
self._connection_params.timeout
|
||||
if hasattr(self._connection_params, "timeout")
|
||||
else None
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
coroutine_func(session), timeout=timeout_in_seconds
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"{error_message}: {e}") from e
|
||||
|
||||
@retry_on_errors
|
||||
async def get_tools(
|
||||
self,
|
||||
@@ -154,26 +188,12 @@ class McpToolset(BaseToolset):
|
||||
Returns:
|
||||
List[BaseTool]: A list of tools available under the specified context.
|
||||
"""
|
||||
headers = (
|
||||
self._header_provider(readonly_context)
|
||||
if self._header_provider and readonly_context
|
||||
else None
|
||||
)
|
||||
# Get session from session manager
|
||||
session = await self._mcp_session_manager.create_session(headers=headers)
|
||||
|
||||
# Fetch available tools from the MCP server
|
||||
timeout_in_seconds = (
|
||||
self._connection_params.timeout
|
||||
if hasattr(self._connection_params, "timeout")
|
||||
else None
|
||||
tools_response: ListToolsResult = await self._execute_with_session(
|
||||
lambda session: session.list_tools(),
|
||||
"Failed to get tools from MCP server",
|
||||
readonly_context,
|
||||
)
|
||||
try:
|
||||
tools_response: ListToolsResult = await asyncio.wait_for(
|
||||
session.list_tools(), timeout=timeout_in_seconds
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to get tools from MCP server: {e}") from e
|
||||
|
||||
# Apply filtering based on context and tool_filter
|
||||
tools = []
|
||||
@@ -191,6 +211,66 @@ class McpToolset(BaseToolset):
|
||||
tools.append(mcp_tool)
|
||||
return tools
|
||||
|
||||
async def read_resource(
|
||||
self, name: str, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> Any:
|
||||
"""Fetches and returns the content of the named resource.
|
||||
|
||||
This method will handle content decoding based on the MIME type reported by
|
||||
the MCP server (e.g., JSON, text, base64 for binary).
|
||||
|
||||
Args:
|
||||
name: The name of the resource to fetch.
|
||||
readonly_context: Context used to provide headers for the MCP session.
|
||||
|
||||
Returns:
|
||||
The content of the resource, decoded based on MIME type and encoding.
|
||||
"""
|
||||
result: Any = await self._execute_with_session(
|
||||
lambda session: session.get_resource(name=name),
|
||||
f"Failed to get resource {name} from MCP server",
|
||||
readonly_context,
|
||||
)
|
||||
|
||||
content = result.content
|
||||
if result.encoding == "base64":
|
||||
decoded_bytes = base64.b64decode(content)
|
||||
if result.resource.mime_type == "application/json":
|
||||
return json.loads(decoded_bytes.decode("utf-8"))
|
||||
if result.resource.mime_type.startswith("text/"):
|
||||
return decoded_bytes.decode("utf-8")
|
||||
return decoded_bytes # Return as bytes for other binary types
|
||||
|
||||
if result.resource.mime_type == "application/json":
|
||||
return json.loads(content)
|
||||
|
||||
return content
|
||||
|
||||
async def list_resources(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> list[str]:
|
||||
"""Returns a list of resource names available on the MCP server."""
|
||||
result: ListResourcesResult = await self._execute_with_session(
|
||||
lambda session: session.list_resources(),
|
||||
"Failed to list resources from MCP server",
|
||||
readonly_context,
|
||||
)
|
||||
return [resource.name for resource in result.resources]
|
||||
|
||||
async def get_resource_info(
|
||||
self, name: str, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> dict[str, Any]:
|
||||
"""Returns metadata about a specific resource (name, MIME type, etc.)."""
|
||||
result: ListResourcesResult = await self._execute_with_session(
|
||||
lambda session: session.list_resources(),
|
||||
"Failed to list resources from MCP server",
|
||||
readonly_context,
|
||||
)
|
||||
for resource in result.resources:
|
||||
if resource.name == name:
|
||||
return resource.model_dump(mode="json", exclude_none=True)
|
||||
raise ValueError(f"Resource with name '{name}' not found.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Performs cleanup and releases resources held by the toolset.
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from io import StringIO
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -31,6 +33,8 @@ from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import ListResourcesResult
|
||||
from mcp.types import Resource
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -353,3 +357,135 @@ class TestMCPToolset:
|
||||
# Assert that the original tools are not modified
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_resources(self):
|
||||
"""Test listing resources."""
|
||||
resources = [
|
||||
Resource(
|
||||
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
|
||||
),
|
||||
Resource(
|
||||
name="data.json",
|
||||
mime_type="application/json",
|
||||
uri="file:///data.json",
|
||||
),
|
||||
]
|
||||
list_resources_result = ListResourcesResult(resources=resources)
|
||||
self.mock_session.list_resources = AsyncMock(
|
||||
return_value=list_resources_result
|
||||
)
|
||||
|
||||
toolset = MCPToolset(connection_params=self.mock_stdio_params)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
result = await toolset.list_resources()
|
||||
|
||||
assert result == ["file1.txt", "data.json"]
|
||||
self.mock_session.list_resources.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resource_info_success(self):
|
||||
"""Test getting resource info for an existing resource."""
|
||||
resources = [
|
||||
Resource(
|
||||
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
|
||||
),
|
||||
Resource(
|
||||
name="data.json",
|
||||
mime_type="application/json",
|
||||
uri="file:///data.json",
|
||||
),
|
||||
]
|
||||
list_resources_result = ListResourcesResult(resources=resources)
|
||||
self.mock_session.list_resources = AsyncMock(
|
||||
return_value=list_resources_result
|
||||
)
|
||||
|
||||
toolset = MCPToolset(connection_params=self.mock_stdio_params)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
result = await toolset.get_resource_info("data.json")
|
||||
|
||||
assert result == {
|
||||
"name": "data.json",
|
||||
"mime_type": "application/json",
|
||||
"uri": "file:///data.json",
|
||||
}
|
||||
self.mock_session.list_resources.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resource_info_not_found(self):
|
||||
"""Test getting resource info for a non-existent resource."""
|
||||
resources = [
|
||||
Resource(
|
||||
name="file1.txt", mime_type="text/plain", uri="file:///file1.txt"
|
||||
),
|
||||
]
|
||||
list_resources_result = ListResourcesResult(resources=resources)
|
||||
self.mock_session.list_resources = AsyncMock(
|
||||
return_value=list_resources_result
|
||||
)
|
||||
|
||||
toolset = MCPToolset(connection_params=self.mock_stdio_params)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Resource with name 'other.json' not found."
|
||||
):
|
||||
await toolset.get_resource_info("other.json")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name,mime_type,content,encoding,expected_result",
|
||||
[
|
||||
("file1.txt", "text/plain", "hello world", None, "hello world"),
|
||||
(
|
||||
"data.json",
|
||||
"application/json",
|
||||
'{"key": "value"}',
|
||||
None,
|
||||
{"key": "value"},
|
||||
),
|
||||
(
|
||||
"file1_b64.txt",
|
||||
"text/plain",
|
||||
base64.b64encode(b"hello world").decode("ascii"),
|
||||
"base64",
|
||||
"hello world",
|
||||
),
|
||||
(
|
||||
"data_b64.json",
|
||||
"application/json",
|
||||
base64.b64encode(b'{"key": "value"}').decode("ascii"),
|
||||
"base64",
|
||||
{"key": "value"},
|
||||
),
|
||||
(
|
||||
"data.bin",
|
||||
"application/octet-stream",
|
||||
base64.b64encode(b"\x01\x02\x03").decode("ascii"),
|
||||
"base64",
|
||||
b"\x01\x02\x03",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_resource(
|
||||
self, name, mime_type, content, encoding, expected_result
|
||||
):
|
||||
"""Test reading various resource types."""
|
||||
get_resource_result = MagicMock()
|
||||
get_resource_result.resource = Resource(
|
||||
name=name, mime_type=mime_type, uri=f"file:///{name}"
|
||||
)
|
||||
get_resource_result.content = content
|
||||
get_resource_result.encoding = encoding
|
||||
self.mock_session.get_resource = AsyncMock(return_value=get_resource_result)
|
||||
|
||||
toolset = MCPToolset(connection_params=self.mock_stdio_params)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
result = await toolset.read_resource(name)
|
||||
|
||||
assert result == expected_result
|
||||
self.mock_session.get_resource.assert_called_once_with(name=name)
|
||||
|
||||
Reference in New Issue
Block a user