You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Create APIRegistryToolset to add tools from Cloud API registry to agent
This calls the cloudapiregistry.googleapis.com API to get MCP tools from the project's registry, and adds them to ADK. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 837166909
This commit is contained in:
committed by
Copybara-Service
parent
f283027e92
commit
ec4ccd718f
@@ -0,0 +1,21 @@
|
||||
# BigQuery API Registry Agent
|
||||
|
||||
This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- A Google Cloud project with the API Registry API enabled.
|
||||
- An MCP server exposing BigQuery tools registered in API Registry.
|
||||
|
||||
## Configuration & Running
|
||||
|
||||
1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server.
|
||||
2. **Run in CLI:**
|
||||
```bash
|
||||
adk run contributing/samples/api_registry_agent -- --log-level DEBUG
|
||||
```
|
||||
3. **Run in Web UI:**
|
||||
```bash
|
||||
adk web contributing/samples/
|
||||
```
|
||||
Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent.
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.tools.api_registry import ApiRegistry
|
||||
|
||||
# TODO: Fill in with your GCloud project id and MCP server name
|
||||
PROJECT_ID = "your-google-cloud-project-id"
|
||||
MCP_SERVER_NAME = "your-mcp-server-name"
|
||||
|
||||
# Header required for BigQuery MCP server
|
||||
header_provider = lambda context: {
|
||||
"x-goog-user-project": PROJECT_ID,
|
||||
}
|
||||
api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider)
|
||||
registry_tools = api_registry.get_toolset(
|
||||
mcp_server_name=MCP_SERVER_NAME,
|
||||
)
|
||||
root_agent = LlmAgent(
|
||||
model="gemini-2.0-flash",
|
||||
name="bigquery_assistant",
|
||||
instruction="""
|
||||
Help user access their BigQuery data via API Registry tools.
|
||||
""",
|
||||
tools=[registry_tools],
|
||||
)
|
||||
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ..auth.auth_tool import AuthToolArguments
|
||||
from .agent_tool import AgentTool
|
||||
from .api_registry import ApiRegistry
|
||||
from .apihub_tool.apihub_toolset import APIHubToolset
|
||||
from .base_tool import BaseTool
|
||||
from .discovery_engine_search_tool import DiscoveryEngineSearchTool
|
||||
@@ -84,6 +85,7 @@ _LAZY_MAPPING = {
|
||||
'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'),
|
||||
'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'),
|
||||
'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'),
|
||||
'ApiRegistry': ('.api_registry', 'ApiRegistry'),
|
||||
}
|
||||
|
||||
__all__ = list(_LAZY_MAPPING.keys())
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
import httpx
|
||||
|
||||
from .base_toolset import ToolPredicate
|
||||
from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from .mcp_tool.mcp_toolset import McpToolset
|
||||
|
||||
# TODO(wukathy): Update to prod URL once it is available.
|
||||
API_REGISTRY_URL = "https://staging-cloudapiregistry.sandbox.googleapis.com"
|
||||
|
||||
|
||||
class ApiRegistry:
|
||||
"""Registry that provides McpToolsets for MCP servers registered in API Registry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_registry_project_id: str,
|
||||
location: str = "global",
|
||||
header_provider: Optional[
|
||||
Callable[[ReadonlyContext], Dict[str, str]]
|
||||
] = None,
|
||||
):
|
||||
"""Initialize the API Registry.
|
||||
|
||||
Args:
|
||||
api_registry_project_id: The project ID for the Google Cloud API Registry.
|
||||
location: The location of the API Registry resources.
|
||||
header_provider: Optional function to provide additional headers for MCP
|
||||
server calls.
|
||||
"""
|
||||
self.api_registry_project_id = api_registry_project_id
|
||||
self.location = location
|
||||
self._credentials, _ = google.auth.default()
|
||||
self._mcp_servers: Dict[str, Dict[str, Any]] = {}
|
||||
self._header_provider = header_provider
|
||||
|
||||
url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
|
||||
try:
|
||||
request = google.auth.transport.requests.Request()
|
||||
self._credentials.refresh(request)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._credentials.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
with httpx.Client() as client:
|
||||
response = client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
mcp_servers_list = response.json().get("mcpServers", [])
|
||||
for server in mcp_servers_list:
|
||||
server_name = server.get("name", "")
|
||||
if server_name:
|
||||
self._mcp_servers[server_name] = server
|
||||
except (httpx.HTTPError, ValueError) as e:
|
||||
# Handle error in fetching or parsing tool definitions
|
||||
raise RuntimeError(
|
||||
f"Error fetching MCP servers from API Registry: {e}"
|
||||
) from e
|
||||
|
||||
def get_toolset(
|
||||
self,
|
||||
mcp_server_name: str,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
tool_name_prefix: Optional[str] = None,
|
||||
) -> McpToolset:
|
||||
"""Return the MCP Toolset based on the params.
|
||||
|
||||
Args:
|
||||
mcp_server_name: Filter to select the MCP server name to get tools
|
||||
from.
|
||||
tool_filter: Optional filter to select specific tools. Can be a list of
|
||||
tool names or a ToolPredicate function.
|
||||
tool_name_prefix: Optional prefix to prepend to the names of the tools
|
||||
returned by the toolset.
|
||||
|
||||
Returns:
|
||||
McpToolset: A toolset for the MCP server specified.
|
||||
"""
|
||||
server = self._mcp_servers.get(mcp_server_name)
|
||||
if not server:
|
||||
raise ValueError(
|
||||
f"MCP server {mcp_server_name} not found in API Registry."
|
||||
)
|
||||
if not server.get("urls"):
|
||||
raise ValueError(f"MCP server {mcp_server_name} has no URLs.")
|
||||
|
||||
mcp_server_url = server["urls"][0]
|
||||
request = google.auth.transport.requests.Request()
|
||||
self._credentials.refresh(request)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._credentials.token}",
|
||||
}
|
||||
return McpToolset(
|
||||
connection_params=StreamableHTTPConnectionParams(
|
||||
url="https://" + mcp_server_url,
|
||||
headers=headers,
|
||||
),
|
||||
tool_filter=tool_filter,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
header_provider=self._header_provider,
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.tools.api_registry import ApiRegistry
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
|
||||
import httpx
|
||||
|
||||
MOCK_MCP_SERVERS_LIST = {
|
||||
"mcpServers": [
|
||||
{
|
||||
"name": "test-mcp-server-1",
|
||||
"urls": ["mcp.server1.com"],
|
||||
},
|
||||
{
|
||||
"name": "test-mcp-server-2",
|
||||
"urls": ["mcp.server2.com"],
|
||||
},
|
||||
{
|
||||
"name": "test-mcp-server-no-url",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class TestApiRegistry(unittest.IsolatedAsyncioTestCase):
|
||||
"""Unit tests for ApiRegistry."""
|
||||
|
||||
def setUp(self):
|
||||
self.project_id = "test-project"
|
||||
self.location = "global"
|
||||
self.mock_credentials = MagicMock()
|
||||
self.mock_credentials.token = "mock_token"
|
||||
self.mock_credentials.refresh = MagicMock()
|
||||
mock_auth_patcher = patch(
|
||||
"google.auth.default",
|
||||
return_value=(self.mock_credentials, None),
|
||||
autospec=True,
|
||||
)
|
||||
mock_auth_patcher.start()
|
||||
self.addCleanup(mock_auth_patcher.stop)
|
||||
|
||||
@patch("httpx.Client", autospec=True)
|
||||
def test_init_success(self, MockHttpClient):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
api_registry = ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
|
||||
self.assertEqual(len(api_registry._mcp_servers), 3)
|
||||
self.assertIn("test-mcp-server-1", api_registry._mcp_servers)
|
||||
self.assertIn("test-mcp-server-2", api_registry._mcp_servers)
|
||||
self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers)
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
f"https://staging-cloudapiregistry.sandbox.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers",
|
||||
headers={
|
||||
"Authorization": "Bearer mock_token",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("httpx.Client", autospec=True)
|
||||
def test_init_http_error(self, MockHttpClient):
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.side_effect = httpx.RequestError(
|
||||
"Connection failed"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
|
||||
ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
|
||||
@patch("httpx.Client", autospec=True)
|
||||
def test_init_bad_response(self, MockHttpClient):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=MagicMock()
|
||||
)
|
||||
)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
|
||||
ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
@patch("google.adk.tools.api_registry.McpToolset", autospec=True)
|
||||
@patch("httpx.Client", autospec=True)
|
||||
async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
api_registry = ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
|
||||
toolset = api_registry.get_toolset("test-mcp-server-1")
|
||||
|
||||
MockMcpToolset.assert_called_once_with(
|
||||
connection_params=StreamableHTTPConnectionParams(
|
||||
url="https://mcp.server1.com",
|
||||
headers={"Authorization": "Bearer mock_token"},
|
||||
),
|
||||
tool_filter=None,
|
||||
tool_name_prefix=None,
|
||||
header_provider=None,
|
||||
)
|
||||
self.assertEqual(toolset, MockMcpToolset.return_value)
|
||||
|
||||
@patch("google.adk.tools.api_registry.McpToolset", autospec=True)
|
||||
@patch("httpx.Client", autospec=True)
|
||||
async def test_get_toolset_with_filter_and_prefix(
|
||||
self, MockHttpClient, MockMcpToolset
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
api_registry = ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
tool_filter = ["tool1"]
|
||||
tool_name_prefix = "prefix_"
|
||||
toolset = api_registry.get_toolset(
|
||||
"test-mcp-server-1",
|
||||
tool_filter=tool_filter,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
|
||||
MockMcpToolset.assert_called_once_with(
|
||||
connection_params=StreamableHTTPConnectionParams(
|
||||
url="https://mcp.server1.com",
|
||||
headers={"Authorization": "Bearer mock_token"},
|
||||
),
|
||||
tool_filter=tool_filter,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
header_provider=None,
|
||||
)
|
||||
self.assertEqual(toolset, MockMcpToolset.return_value)
|
||||
|
||||
@patch("httpx.Client", autospec=True)
|
||||
async def test_get_toolset_server_not_found(self, MockHttpClient):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
api_registry = ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "not found in API Registry"):
|
||||
api_registry.get_toolset("non-existent-server")
|
||||
|
||||
@patch("httpx.Client", autospec=True)
|
||||
async def test_get_toolset_server_no_url(self, MockHttpClient):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
|
||||
mock_client_instance = MockHttpClient.return_value
|
||||
mock_client_instance.__enter__.return_value = mock_client_instance
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
|
||||
api_registry = ApiRegistry(
|
||||
api_registry_project_id=self.project_id, location=self.location
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "has no URLs"):
|
||||
api_registry.get_toolset("test-mcp-server-no-url")
|
||||
Reference in New Issue
Block a user