feat: Add a load MCP resource tool

If the user specifies use_mcp_resources=True in their MCPToolset, the agent will be able to load resources with the load_mcp_resource_tool.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 866539602
This commit is contained in:
Kathy Wu
2026-02-06 11:17:59 -08:00
committed by Copybara-Service
parent c7362100eb
commit e25227da5e
6 changed files with 393 additions and 5 deletions
@@ -52,6 +52,7 @@ Allowed directory: {_allowed_path}
'get_file_info',
'list_allowed_directories',
],
use_mcp_resources=True,
)
],
)
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
import json
import os
from pathlib import Path
import sys
@@ -45,6 +46,24 @@ def get_cwd() -> str:
return str(Path.cwd())
# Add a resource for testing with JSON data
@mcp.resource(
name="sample_data",
uri="file:///sample_data.json",
mime_type="application/json",
)
def sample_data() -> str:
data = {
"users": [
{"id": 1, "name": "Alice", "role": "admin"},
{"id": 2, "name": "Bob", "role": "user"},
{"id": 3, "name": "Charlie", "role": "user"},
],
"settings": {"theme": "dark", "notifications": True},
}
return json.dumps(data, indent=2)
# Graceful shutdown handler
async def shutdown(signal, loop):
"""Cleanup tasks tied to the service's shutdown."""
@@ -0,0 +1,170 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
import json
import logging
from typing import Any
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from ..features import FeatureName
from ..features import is_feature_enabled
from ..models.llm_request import LlmRequest
from .base_tool import BaseTool
if TYPE_CHECKING:
from mcp_toolset import McpToolset
from .tool_context import ToolContext
logger = logging.getLogger("google_adk." + __name__)
class LoadMcpResourceTool(BaseTool):
"""A tool that loads the MCP resources and adds them to the session."""
def __init__(self, mcp_toolset: McpToolset):
super().__init__(
name="load_mcp_resource",
description="""Loads resources from the MCP server.
NOTE: Call when you need access to resources.""",
)
self._mcp_toolset = mcp_toolset
def _get_declaration(self) -> types.FunctionDeclaration | None:
if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters_json_schema={
"type": "object",
"properties": {
"resource_names": {
"type": "array",
"items": {"type": "string"},
},
},
},
)
return types.FunctionDeclaration(
name=self.name,
description=self.description,
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"resource_names": types.Schema(
type=types.Type.ARRAY,
items=types.Schema(
type=types.Type.STRING,
),
)
},
),
)
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
resource_names: list[str] = args.get("resource_names", [])
return {
"resource_names": resource_names,
"status": (
"resource contents temporarily inserted and removed. to access"
" these resources, call load_mcp_resource tool again."
),
}
@override
async def process_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
await super().process_llm_request(
tool_context=tool_context,
llm_request=llm_request,
)
await self._append_resources_to_llm_request(
tool_context=tool_context, llm_request=llm_request
)
async def _append_resources_to_llm_request(
self, *, tool_context: ToolContext, llm_request: LlmRequest
):
try:
resource_names = await self._mcp_toolset.list_resources()
if resource_names:
llm_request.append_instructions([f"""You have a list of MCP resources:
{json.dumps(resource_names)}
When the user asks questions about any of the resources, you should call the
`load_mcp_resource` function to load the resource. Always call load_mcp_resource
before answering questions related to the resources.
"""])
except Exception as e:
logger.warning("Failed to list MCP resources: %s", e)
# Attach content
if llm_request.contents and llm_request.contents[-1].parts:
function_response = llm_request.contents[-1].parts[0].function_response
if function_response and function_response.name == self.name:
response = function_response.response or {}
resource_names = response.get("resource_names", [])
for resource_name in resource_names:
try:
contents = await self._mcp_toolset.read_resource(resource_name)
for content in contents:
part = self._mcp_content_to_part(content, resource_name)
llm_request.contents.append(
types.Content(
role="user",
parts=[
types.Part.from_text(
text=f"Resource {resource_name} is:"
),
part,
],
)
)
except Exception as e:
logger.warning(
"Failed to read MCP resource '%s': %s", resource_name, e
)
continue
def _mcp_content_to_part(
self, content: Any, resource_name: str
) -> types.Part:
if hasattr(content, "text") and content.text is not None:
return types.Part.from_text(text=content.text)
elif hasattr(content, "blob") and content.blob is not None:
try:
data = base64.b64decode(content.blob)
# Basic check for mime type or default
mime_type = content.mimeType or "application/octet-stream"
return types.Part.from_bytes(data=data, mime_type=mime_type)
except Exception:
return types.Part.from_text(
text=f"[Binary content for {resource_name} could not be decoded]"
)
else:
return types.Part.from_text(
text=f"[Unknown content type for {resource_name}]"
)
+18 -1
View File
@@ -16,7 +16,6 @@ from __future__ import annotations
import asyncio
import base64
import json
import logging
import sys
from typing import Any
@@ -44,8 +43,10 @@ from ...auth.auth_tool import AuthConfig
from ..base_tool import BaseTool
from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate
from ..load_mcp_resource_tool import LoadMcpResourceTool
from ..tool_configs import BaseToolConfig
from ..tool_configs import ToolArgsConfig
from ..tool_context import ToolContext
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
from .mcp_session_manager import SseConnectionParams
@@ -112,6 +113,7 @@ class McpToolset(BaseToolset):
progress_callback: Optional[
Union[ProgressFnT, ProgressCallbackFactory]
] = None,
use_mcp_resources: Optional[bool] = False,
):
"""Initializes the McpToolset.
@@ -148,7 +150,11 @@ class McpToolset(BaseToolset):
progress handling logic and access/modify session state via the
CallbackContext. The **kwargs parameter allows for future
extensibility.
use_mcp_resources: Whether the agent should have access to MCP resources.
This will add a `load_mcp_resource` tool to the toolset and include
available resources in the agent context. Defaults to False.
"""
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
if not connection_params:
@@ -177,6 +183,7 @@ class McpToolset(BaseToolset):
if auth_scheme
else None
)
self._use_mcp_resources = use_mcp_resources
def _get_auth_headers(self) -> Optional[Dict[str, str]]:
"""Build authentication headers from exchanged credential.
@@ -317,6 +324,13 @@ class McpToolset(BaseToolset):
if self._is_tool_selected(mcp_tool, readonly_context):
tools.append(mcp_tool)
if self._use_mcp_resources:
load_resource_tool = LoadMcpResourceTool(
mcp_toolset=self,
)
tools.append(load_resource_tool)
return tools
async def read_resource(
@@ -415,6 +429,7 @@ class McpToolset(BaseToolset):
tool_name_prefix=mcp_toolset_config.tool_name_prefix,
auth_scheme=mcp_toolset_config.auth_scheme,
auth_credential=mcp_toolset_config.auth_credential,
use_mcp_resources=mcp_toolset_config.use_mcp_resources,
)
@@ -451,6 +466,8 @@ class McpToolsetConfig(BaseToolConfig):
auth_credential: Optional[AuthCredential] = None
use_mcp_resources: bool = False
@model_validator(mode="after")
def _check_only_one_params_field(self):
param_fields = [
@@ -25,6 +25,7 @@ from unittest.mock import patch
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.auth.auth_credential import AuthCredential
from google.adk.tools.load_mcp_resource_tool import LoadMcpResourceTool
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
@@ -81,6 +82,14 @@ class TestMcpToolset:
assert toolset._errlog == sys.stderr
assert toolset._auth_scheme is None
assert toolset._auth_credential is None
assert toolset._use_mcp_resources is False
def test_init_with_use_mcp_resources(self):
"""Test initialization with use_mcp_resources."""
toolset = McpToolset(
connection_params=self.mock_stdio_params, use_mcp_resources=True
)
assert toolset._use_mcp_resources is True
def test_init_with_stdio_connection_params(self):
"""Test initialization with StdioConnectionParams."""
@@ -161,17 +170,21 @@ class TestMcpToolset:
return_value=MockListToolsResult(mock_tools)
)
toolset = McpToolset(connection_params=self.mock_stdio_params)
toolset = McpToolset(
connection_params=self.mock_stdio_params, use_mcp_resources=True
)
toolset._mcp_session_manager = self.mock_session_manager
tools = await toolset.get_tools()
assert len(tools) == 3
for tool in tools:
assert len(tools) == 4
for tool in tools[:3]:
assert isinstance(tool, MCPTool)
assert isinstance(tools[3], LoadMcpResourceTool)
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"
assert tools[2].name == "tool3"
assert tools[3].name == "load_mcp_resource"
@pytest.mark.asyncio
async def test_get_tools_with_list_filter(self):
@@ -338,6 +351,7 @@ class TestMcpToolset:
toolset = McpToolset(
connection_params=mock_connection_params,
tool_name_prefix="my_prefix",
use_mcp_resources=True,
)
# Replace the internal session manager with our mock
@@ -352,13 +366,15 @@ class TestMcpToolset:
prefixed_tools = await toolset.get_tools_with_prefix()
# Assert that the tools are prefixed correctly
assert len(prefixed_tools) == 2
assert len(prefixed_tools) == 3
assert prefixed_tools[0].name == "my_prefix_tool1"
assert prefixed_tools[1].name == "my_prefix_tool2"
assert prefixed_tools[2].name == "my_prefix_load_mcp_resource"
# Assert that the original tools are not modified
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"
assert tools[2].name == "load_mcp_resource"
def test_init_with_progress_callback(self):
"""Test initialization with progress_callback."""
@@ -0,0 +1,165 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
from google.adk.models.llm_request import LlmRequest
from google.adk.tools.load_mcp_resource_tool import LoadMcpResourceTool
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from google.adk.tools.tool_context import ToolContext
from google.genai import types
from mcp.types import BlobResourceContents
from mcp.types import TextResourceContents
import pytest
class TestLoadMcpResourceTool:
"""Test suite for LoadMcpResourceTool class."""
def setup_method(self):
"""Set up test fixtures."""
self.mock_mcp_toolset = Mock(spec=McpToolset)
self.mock_tool_context = Mock(spec=ToolContext)
def test_init(self):
"""Test initialization."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
assert tool.name == "load_mcp_resource"
assert tool._mcp_toolset == self.mock_mcp_toolset
@pytest.mark.asyncio
async def test_run_async(self):
"""Test run_async method."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
args = {"resource_names": ["res1", "res2"]}
result = await tool.run_async(
args=args, tool_context=self.mock_tool_context
)
assert result["resource_names"] == ["res1", "res2"]
assert "temporarily inserted" in result["status"]
def test_get_declaration(self):
"""Test _get_declaration method."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
declaration = tool._get_declaration()
assert isinstance(declaration, types.FunctionDeclaration)
assert declaration.name == "load_mcp_resource"
# Basic schema check, precise structure depends on is_feature_enabled
# and implementation details which might vary.
@pytest.mark.asyncio
async def test_process_llm_request_injects_list(self):
"""Test that resource list is injected when enabled."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
llm_request = Mock(spec=LlmRequest)
llm_request.contents = []
# Mock list_resources
self.mock_mcp_toolset.list_resources = AsyncMock(
return_value=["res1", "res2"]
)
await tool.process_llm_request(
tool_context=self.mock_tool_context, llm_request=llm_request
)
llm_request.append_instructions.assert_called_once()
instructions = llm_request.append_instructions.call_args[0][0]
assert "res1" in instructions[0]
assert "res2" in instructions[0]
async def test_process_llm_request_loads_content_text(self):
"""Test loading text resource content."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
llm_request = Mock(spec=LlmRequest)
llm_request.contents = []
# Setup LLM request with function call response asking for "res1"
function_response = Mock()
function_response.name = "load_mcp_resource"
function_response.response = {"resource_names": ["res1"]}
part = Mock()
part.function_response = function_response
content = Mock()
content.parts = [part]
llm_request.contents = [content]
# Mock read_resource
text_content = TextResourceContents(
uri="file:///res1", mimeType="text/plain", text="hello content"
)
self.mock_mcp_toolset.read_resource = AsyncMock(return_value=[text_content])
await tool.process_llm_request(
tool_context=self.mock_tool_context, llm_request=llm_request
)
# Verify content was appended
assert len(llm_request.contents) == 2 # Original + new content
new_content = llm_request.contents[1]
assert new_content.role == "user"
assert len(new_content.parts) == 2
assert "Resource res1 is:" in new_content.parts[0].text
assert new_content.parts[1].text == "hello content"
@pytest.mark.asyncio
async def test_process_llm_request_loads_content_binary(self):
"""Test loading binary resource content."""
tool = LoadMcpResourceTool(mcp_toolset=self.mock_mcp_toolset)
llm_request = Mock(spec=LlmRequest)
llm_request.contents = []
# Setup LLM request with function call response asking for "res1"
function_response = Mock()
function_response.name = "load_mcp_resource"
function_response.response = {"resource_names": ["res1"]}
part = Mock()
part.function_response = function_response
content = Mock()
content.parts = [part]
llm_request.contents = [content]
# Mock read_resource
blob_data = b"binary data"
blob_b64 = base64.b64encode(blob_data).decode("ascii")
blob_content = BlobResourceContents(
uri="file:///res1", mimeType="image/png", blob=blob_b64
)
self.mock_mcp_toolset.read_resource = AsyncMock(return_value=[blob_content])
await tool.process_llm_request(
tool_context=self.mock_tool_context, llm_request=llm_request
)
# Verify content was appended
assert len(llm_request.contents) == 2
new_content = llm_request.contents[1]
# Check that the second part is bytes
# Note: google.genai.types.Part.from_bytes creates a Part with inline_data
# Accessing it depends on the Part implementation.
# Since we are using real types.Part (not mocked), we can check attributes.
part = new_content.parts[1]
assert part.inline_data is not None
assert part.inline_data.mime_type == "image/png"
assert part.inline_data.data == blob_data