You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add a load MCP resource tool
If the user specifies use_mcp_resources=True in their MCPToolset, the agent will be able to load resources with the load_mcp_resource_tool. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 866539602
This commit is contained in:
committed by
Copybara-Service
parent
c7362100eb
commit
e25227da5e
@@ -52,6 +52,7 @@ Allowed directory: {_allowed_path}
|
||||
'get_file_info',
|
||||
'list_allowed_directories',
|
||||
],
|
||||
use_mcp_resources=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
@@ -45,6 +46,24 @@ def get_cwd() -> str:
|
||||
return str(Path.cwd())
|
||||
|
||||
|
||||
# Add a resource for testing with JSON data
|
||||
@mcp.resource(
|
||||
name="sample_data",
|
||||
uri="file:///sample_data.json",
|
||||
mime_type="application/json",
|
||||
)
|
||||
def sample_data() -> str:
|
||||
data = {
|
||||
"users": [
|
||||
{"id": 1, "name": "Alice", "role": "admin"},
|
||||
{"id": 2, "name": "Bob", "role": "user"},
|
||||
{"id": 3, "name": "Charlie", "role": "user"},
|
||||
],
|
||||
"settings": {"theme": "dark", "notifications": True},
|
||||
}
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
|
||||
# Graceful shutdown handler
|
||||
async def shutdown(signal, loop):
|
||||
"""Cleanup tasks tied to the service's shutdown."""
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..features import FeatureName
|
||||
from ..features import is_feature_enabled
|
||||
from ..models.llm_request import LlmRequest
|
||||
from .base_tool import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp_toolset import McpToolset
|
||||
|
||||
from .tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class LoadMcpResourceTool(BaseTool):
|
||||
"""A tool that loads the MCP resources and adds them to the session."""
|
||||
|
||||
def __init__(self, mcp_toolset: McpToolset):
|
||||
super().__init__(
|
||||
name="load_mcp_resource",
|
||||
description="""Loads resources from the MCP server.
|
||||
|
||||
NOTE: Call when you need access to resources.""",
|
||||
)
|
||||
self._mcp_toolset = mcp_toolset
|
||||
|
||||
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
||||
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
parameters_json_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"resource_names": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
parameters=types.Schema(
|
||||
type=types.Type.OBJECT,
|
||||
properties={
|
||||
"resource_names": types.Schema(
|
||||
type=types.Type.ARRAY,
|
||||
items=types.Schema(
|
||||
type=types.Type.STRING,
|
||||
),
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
resource_names: list[str] = args.get("resource_names", [])
|
||||
return {
|
||||
"resource_names": resource_names,
|
||||
"status": (
|
||||
"resource contents temporarily inserted and removed. to access"
|
||||
" these resources, call load_mcp_resource tool again."
|
||||
),
|
||||
}
|
||||
|
||||
@override
|
||||
async def process_llm_request(
|
||||
self, *, tool_context: ToolContext, llm_request: LlmRequest
|
||||
) -> None:
|
||||
await super().process_llm_request(
|
||||
tool_context=tool_context,
|
||||
llm_request=llm_request,
|
||||
)
|
||||
await self._append_resources_to_llm_request(
|
||||
tool_context=tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
async def _append_resources_to_llm_request(
|
||||
self, *, tool_context: ToolContext, llm_request: LlmRequest
|
||||
):
|
||||
try:
|
||||
resource_names = await self._mcp_toolset.list_resources()
|
||||
if resource_names:
|
||||
llm_request.append_instructions([f"""You have a list of MCP resources:
|
||||
{json.dumps(resource_names)}
|
||||
|
||||
When the user asks questions about any of the resources, you should call the
|
||||
`load_mcp_resource` function to load the resource. Always call load_mcp_resource
|
||||
before answering questions related to the resources.
|
||||
"""])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to list MCP resources: %s", e)
|
||||
|
||||
# Attach content
|
||||
if llm_request.contents and llm_request.contents[-1].parts:
|
||||
function_response = llm_request.contents[-1].parts[0].function_response
|
||||
if function_response and function_response.name == self.name:
|
||||
response = function_response.response or {}
|
||||
resource_names = response.get("resource_names", [])
|
||||
for resource_name in resource_names:
|
||||
try:
|
||||
contents = await self._mcp_toolset.read_resource(resource_name)
|
||||
|
||||
for content in contents:
|
||||
part = self._mcp_content_to_part(content, resource_name)
|
||||
llm_request.contents.append(
|
||||
types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=f"Resource {resource_name} is:"
|
||||
),
|
||||
part,
|
||||
],
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to read MCP resource '%s': %s", resource_name, e
|
||||
)
|
||||
continue
|
||||
|
||||
def _mcp_content_to_part(
|
||||
self, content: Any, resource_name: str
|
||||
) -> types.Part:
|
||||
if hasattr(content, "text") and content.text is not None:
|
||||
return types.Part.from_text(text=content.text)
|
||||
elif hasattr(content, "blob") and content.blob is not None:
|
||||
try:
|
||||
data = base64.b64decode(content.blob)
|
||||
# Basic check for mime type or default
|
||||
mime_type = content.mimeType or "application/octet-stream"
|
||||
return types.Part.from_bytes(data=data, mime_type=mime_type)
|
||||
except Exception:
|
||||
return types.Part.from_text(
|
||||
text=f"[Binary content for {resource_name} could not be decoded]"
|
||||
)
|
||||
else:
|
||||
return types.Part.from_text(
|
||||
text=f"[Unknown content type for {resource_name}]"
|
||||
)
|
||||
@@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
@@ -44,8 +43,10 @@ from ...auth.auth_tool import AuthConfig
|
||||
from ..base_tool import BaseTool
|
||||
from ..base_toolset import BaseToolset
|
||||
from ..base_toolset import ToolPredicate
|
||||
from ..load_mcp_resource_tool import LoadMcpResourceTool
|
||||
from ..tool_configs import BaseToolConfig
|
||||
from ..tool_configs import ToolArgsConfig
|
||||
from ..tool_context import ToolContext
|
||||
from .mcp_session_manager import MCPSessionManager
|
||||
from .mcp_session_manager import retry_on_errors
|
||||
from .mcp_session_manager import SseConnectionParams
|
||||
@@ -112,6 +113,7 @@ class McpToolset(BaseToolset):
|
||||
progress_callback: Optional[
|
||||
Union[ProgressFnT, ProgressCallbackFactory]
|
||||
] = None,
|
||||
use_mcp_resources: Optional[bool] = False,
|
||||
):
|
||||
"""Initializes the McpToolset.
|
||||
|
||||
@@ -148,7 +150,11 @@ class McpToolset(BaseToolset):
|
||||
progress handling logic and access/modify session state via the
|
||||
CallbackContext. The **kwargs parameter allows for future
|
||||
extensibility.
|
||||
use_mcp_resources: Whether the agent should have access to MCP resources.
|
||||
This will add a `load_mcp_resource` tool to the toolset and include
|
||||
available resources in the agent context. Defaults to False.
|
||||
"""
|
||||
|
||||
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
|
||||
|
||||
if not connection_params:
|
||||
@@ -177,6 +183,7 @@ class McpToolset(BaseToolset):
|
||||
if auth_scheme
|
||||
else None
|
||||
)
|
||||
self._use_mcp_resources = use_mcp_resources
|
||||
|
||||
def _get_auth_headers(self) -> Optional[Dict[str, str]]:
|
||||
"""Build authentication headers from exchanged credential.
|
||||
@@ -317,6 +324,13 @@ class McpToolset(BaseToolset):
|
||||
|
||||
if self._is_tool_selected(mcp_tool, readonly_context):
|
||||
tools.append(mcp_tool)
|
||||
|
||||
if self._use_mcp_resources:
|
||||
load_resource_tool = LoadMcpResourceTool(
|
||||
mcp_toolset=self,
|
||||
)
|
||||
tools.append(load_resource_tool)
|
||||
|
||||
return tools
|
||||
|
||||
async def read_resource(
|
||||
@@ -415,6 +429,7 @@ class McpToolset(BaseToolset):
|
||||
tool_name_prefix=mcp_toolset_config.tool_name_prefix,
|
||||
auth_scheme=mcp_toolset_config.auth_scheme,
|
||||
auth_credential=mcp_toolset_config.auth_credential,
|
||||
use_mcp_resources=mcp_toolset_config.use_mcp_resources,
|
||||
)
|
||||
|
||||
|
||||
@@ -451,6 +466,8 @@ class McpToolsetConfig(BaseToolConfig):
|
||||
|
||||
auth_credential: Optional[AuthCredential] = None
|
||||
|
||||
use_mcp_resources: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_only_one_params_field(self):
|
||||
param_fields = [
|
||||
|
||||
@@ -25,6 +25,7 @@ from unittest.mock import patch
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.tools.load_mcp_resource_tool import LoadMcpResourceTool
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
|
||||
@@ -81,6 +82,14 @@ class TestMcpToolset:
|
||||
assert toolset._errlog == sys.stderr
|
||||
assert toolset._auth_scheme is None
|
||||
assert toolset._auth_credential is None
|
||||
assert toolset._use_mcp_resources is False
|
||||
|
||||
def test_init_with_use_mcp_resources(self):
|
||||
"""Test initialization with use_mcp_resources."""
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params, use_mcp_resources=True
|
||||
)
|
||||
assert toolset._use_mcp_resources is True
|
||||
|
||||
def test_init_with_stdio_connection_params(self):
|
||||
"""Test initialization with StdioConnectionParams."""
|
||||
@@ -161,17 +170,21 @@ class TestMcpToolset:
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
|
||||
toolset = McpToolset(connection_params=self.mock_stdio_params)
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params, use_mcp_resources=True
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
for tool in tools:
|
||||
assert len(tools) == 4
|
||||
for tool in tools[:3]:
|
||||
assert isinstance(tool, MCPTool)
|
||||
assert isinstance(tools[3], LoadMcpResourceTool)
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
assert tools[2].name == "tool3"
|
||||
assert tools[3].name == "load_mcp_resource"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_list_filter(self):
|
||||
@@ -338,6 +351,7 @@ class TestMcpToolset:
|
||||
toolset = McpToolset(
|
||||
connection_params=mock_connection_params,
|
||||
tool_name_prefix="my_prefix",
|
||||
use_mcp_resources=True,
|
||||
)
|
||||
|
||||
# Replace the internal session manager with our mock
|
||||
@@ -352,13 +366,15 @@ class TestMcpToolset:
|
||||
prefixed_tools = await toolset.get_tools_with_prefix()
|
||||
|
||||
# Assert that the tools are prefixed correctly
|
||||
assert len(prefixed_tools) == 2
|
||||
assert len(prefixed_tools) == 3
|
||||
assert prefixed_tools[0].name == "my_prefix_tool1"
|
||||
assert prefixed_tools[1].name == "my_prefix_tool2"
|
||||
assert prefixed_tools[2].name == "my_prefix_load_mcp_resource"
|
||||
|
||||
# Assert that the original tools are not modified
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
assert tools[2].name == "load_mcp_resource"
|
||||
|
||||
def test_init_with_progress_callback(self):
|
||||
"""Test initialization with progress_callback."""
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.tools.load_mcp_resource_tool import LoadMcpResourceTool
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
from mcp.types import BlobResourceContents
|
||||
from mcp.types import TextResourceContents
|
||||
import pytest
|
||||
|
||||
|
||||
class TestLoadMcpResourceTool:
|
||||
"""Test suite for LoadMcpResourceTool class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.mock_mcp_toolset = Mock(spec=McpToolset)
|
||||
self.mock_tool_context = Mock(spec=ToolContext)
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
assert tool.name == "load_mcp_resource"
|
||||
assert tool._mcp_toolset == self.mock_mcp_toolset
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async(self):
|
||||
"""Test run_async method."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
args = {"resource_names": ["res1", "res2"]}
|
||||
result = await tool.run_async(
|
||||
args=args, tool_context=self.mock_tool_context
|
||||
)
|
||||
|
||||
assert result["resource_names"] == ["res1", "res2"]
|
||||
assert "temporarily inserted" in result["status"]
|
||||
|
||||
def test_get_declaration(self):
|
||||
"""Test _get_declaration method."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
assert isinstance(declaration, types.FunctionDeclaration)
|
||||
assert declaration.name == "load_mcp_resource"
|
||||
# Basic schema check, precise structure depends on is_feature_enabled
|
||||
# and implementation details which might vary.
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_injects_list(self):
|
||||
"""Test that resource list is injected when enabled."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
llm_request = Mock(spec=LlmRequest)
|
||||
llm_request.contents = []
|
||||
|
||||
# Mock list_resources
|
||||
self.mock_mcp_toolset.list_resources = AsyncMock(
|
||||
return_value=["res1", "res2"]
|
||||
)
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=self.mock_tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
llm_request.append_instructions.assert_called_once()
|
||||
instructions = llm_request.append_instructions.call_args[0][0]
|
||||
assert "res1" in instructions[0]
|
||||
assert "res2" in instructions[0]
|
||||
|
||||
async def test_process_llm_request_loads_content_text(self):
|
||||
"""Test loading text resource content."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
llm_request = Mock(spec=LlmRequest)
|
||||
llm_request.contents = []
|
||||
|
||||
# Setup LLM request with function call response asking for "res1"
|
||||
function_response = Mock()
|
||||
function_response.name = "load_mcp_resource"
|
||||
function_response.response = {"resource_names": ["res1"]}
|
||||
|
||||
part = Mock()
|
||||
part.function_response = function_response
|
||||
|
||||
content = Mock()
|
||||
content.parts = [part]
|
||||
llm_request.contents = [content]
|
||||
|
||||
# Mock read_resource
|
||||
text_content = TextResourceContents(
|
||||
uri="file:///res1", mimeType="text/plain", text="hello content"
|
||||
)
|
||||
self.mock_mcp_toolset.read_resource = AsyncMock(return_value=[text_content])
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=self.mock_tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# Verify content was appended
|
||||
assert len(llm_request.contents) == 2 # Original + new content
|
||||
new_content = llm_request.contents[1]
|
||||
assert new_content.role == "user"
|
||||
assert len(new_content.parts) == 2
|
||||
assert "Resource res1 is:" in new_content.parts[0].text
|
||||
assert new_content.parts[1].text == "hello content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_llm_request_loads_content_binary(self):
|
||||
"""Test loading binary resource content."""
|
||||
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
|
||||
llm_request = Mock(spec=LlmRequest)
|
||||
llm_request.contents = []
|
||||
|
||||
# Setup LLM request with function call response asking for "res1"
|
||||
function_response = Mock()
|
||||
function_response.name = "load_mcp_resource"
|
||||
function_response.response = {"resource_names": ["res1"]}
|
||||
|
||||
part = Mock()
|
||||
part.function_response = function_response
|
||||
|
||||
content = Mock()
|
||||
content.parts = [part]
|
||||
llm_request.contents = [content]
|
||||
|
||||
# Mock read_resource
|
||||
blob_data = b"binary data"
|
||||
blob_b64 = base64.b64encode(blob_data).decode("ascii")
|
||||
blob_content = BlobResourceContents(
|
||||
uri="file:///res1", mimeType="image/png", blob=blob_b64
|
||||
)
|
||||
self.mock_mcp_toolset.read_resource = AsyncMock(return_value=[blob_content])
|
||||
|
||||
await tool.process_llm_request(
|
||||
tool_context=self.mock_tool_context, llm_request=llm_request
|
||||
)
|
||||
|
||||
# Verify content was appended
|
||||
assert len(llm_request.contents) == 2
|
||||
new_content = llm_request.contents[1]
|
||||
# Check that the second part is bytes
|
||||
# Note: google.genai.types.Part.from_bytes creates a Part with inline_data
|
||||
# Accessing it depends on the Part implementation.
|
||||
# Since we are using real types.Part (not mocked), we can check attributes.
|
||||
part = new_content.parts[1]
|
||||
assert part.inline_data is not None
|
||||
assert part.inline_data.mime_type == "image/png"
|
||||
assert part.inline_data.data == blob_data
|
||||
Reference in New Issue
Block a user