feat: Create APIRegistryToolset to add tools from Cloud API registry to agent

This calls the cloudapiregistry.googleapis.com API to get MCP tools from the project's registry, and adds them to ADK.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 837166909
This commit is contained in:
Kathy Wu
2025-11-26 10:01:22 -08:00
committed by Copybara-Service
parent f283027e92
commit ec4ccd718f
6 changed files with 406 additions and 0 deletions
@@ -0,0 +1,21 @@
# BigQuery API Registry Agent
This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry.
## Prerequisites
- A Google Cloud project with the API Registry API enabled.
- An MCP server exposing BigQuery tools registered in API Registry.
## Configuration & Running
1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server.
2. **Run in CLI:**
```bash
adk run contributing/samples/api_registry_agent -- --log-level DEBUG
```
3. **Run in Web UI:**
```bash
adk web contributing/samples/
```
Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent.
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,39 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.api_registry import ApiRegistry
# TODO: Fill in with your GCloud project id and MCP server name
PROJECT_ID = "your-google-cloud-project-id"
MCP_SERVER_NAME = "your-mcp-server-name"
# Header required for BigQuery MCP server
header_provider = lambda context: {
"x-goog-user-project": PROJECT_ID,
}
api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider)
registry_tools = api_registry.get_toolset(
mcp_server_name=MCP_SERVER_NAME,
)
root_agent = LlmAgent(
model="gemini-2.0-flash",
name="bigquery_assistant",
instruction="""
Help user access their BigQuery data via API Registry tools.
""",
tools=[registry_tools],
)
+2
View File
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..auth.auth_tool import AuthToolArguments
from .agent_tool import AgentTool
from .api_registry import ApiRegistry
from .apihub_tool.apihub_toolset import APIHubToolset
from .base_tool import BaseTool
from .discovery_engine_search_tool import DiscoveryEngineSearchTool
@@ -84,6 +85,7 @@ _LAZY_MAPPING = {
'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'),
'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'),
'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'),
'ApiRegistry': ('.api_registry', 'ApiRegistry'),
}
__all__ = list(_LAZY_MAPPING.keys())
+124
View File
@@ -0,0 +1,124 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import sys
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import google.auth
import google.auth.transport.requests
import httpx
from .base_toolset import ToolPredicate
from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from .mcp_tool.mcp_toolset import McpToolset
# TODO(wukathy): Update to prod URL once it is available.
API_REGISTRY_URL = "https://staging-cloudapiregistry.sandbox.googleapis.com"
class ApiRegistry:
"""Registry that provides McpToolsets for MCP servers registered in API Registry."""
def __init__(
self,
api_registry_project_id: str,
location: str = "global",
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
):
"""Initialize the API Registry.
Args:
api_registry_project_id: The project ID for the Google Cloud API Registry.
location: The location of the API Registry resources.
header_provider: Optional function to provide additional headers for MCP
server calls.
"""
self.api_registry_project_id = api_registry_project_id
self.location = location
self._credentials, _ = google.auth.default()
self._mcp_servers: Dict[str, Dict[str, Any]] = {}
self._header_provider = header_provider
url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
try:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
headers = {
"Authorization": f"Bearer {self._credentials.token}",
"Content-Type": "application/json",
}
with httpx.Client() as client:
response = client.get(url, headers=headers)
response.raise_for_status()
mcp_servers_list = response.json().get("mcpServers", [])
for server in mcp_servers_list:
server_name = server.get("name", "")
if server_name:
self._mcp_servers[server_name] = server
except (httpx.HTTPError, ValueError) as e:
# Handle error in fetching or parsing tool definitions
raise RuntimeError(
f"Error fetching MCP servers from API Registry: {e}"
) from e
def get_toolset(
self,
mcp_server_name: str,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
tool_name_prefix: Optional[str] = None,
) -> McpToolset:
"""Return the MCP Toolset based on the params.
Args:
mcp_server_name: Filter to select the MCP server name to get tools
from.
tool_filter: Optional filter to select specific tools. Can be a list of
tool names or a ToolPredicate function.
tool_name_prefix: Optional prefix to prepend to the names of the tools
returned by the toolset.
Returns:
McpToolset: A toolset for the MCP server specified.
"""
server = self._mcp_servers.get(mcp_server_name)
if not server:
raise ValueError(
f"MCP server {mcp_server_name} not found in API Registry."
)
if not server.get("urls"):
raise ValueError(f"MCP server {mcp_server_name} has no URLs.")
mcp_server_url = server["urls"][0]
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
headers = {
"Authorization": f"Bearer {self._credentials.token}",
}
return McpToolset(
connection_params=StreamableHTTPConnectionParams(
url="https://" + mcp_server_url,
headers=headers,
),
tool_filter=tool_filter,
tool_name_prefix=tool_name_prefix,
header_provider=self._header_provider,
)
+205
View File
@@ -0,0 +1,205 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch
from google.adk.tools.api_registry import ApiRegistry
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
import httpx
MOCK_MCP_SERVERS_LIST = {
"mcpServers": [
{
"name": "test-mcp-server-1",
"urls": ["mcp.server1.com"],
},
{
"name": "test-mcp-server-2",
"urls": ["mcp.server2.com"],
},
{
"name": "test-mcp-server-no-url",
},
]
}
class TestApiRegistry(unittest.IsolatedAsyncioTestCase):
"""Unit tests for ApiRegistry."""
def setUp(self):
self.project_id = "test-project"
self.location = "global"
self.mock_credentials = MagicMock()
self.mock_credentials.token = "mock_token"
self.mock_credentials.refresh = MagicMock()
mock_auth_patcher = patch(
"google.auth.default",
return_value=(self.mock_credentials, None),
autospec=True,
)
mock_auth_patcher.start()
self.addCleanup(mock_auth_patcher.stop)
@patch("httpx.Client", autospec=True)
def test_init_success(self, MockHttpClient):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
api_registry = ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
self.assertEqual(len(api_registry._mcp_servers), 3)
self.assertIn("test-mcp-server-1", api_registry._mcp_servers)
self.assertIn("test-mcp-server-2", api_registry._mcp_servers)
self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers)
mock_client_instance.get.assert_called_once_with(
f"https://staging-cloudapiregistry.sandbox.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers",
headers={
"Authorization": "Bearer mock_token",
"Content-Type": "application/json",
},
)
@patch("httpx.Client", autospec=True)
def test_init_http_error(self, MockHttpClient):
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.side_effect = httpx.RequestError(
"Connection failed"
)
with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
@patch("httpx.Client", autospec=True)
def test_init_bad_response(self, MockHttpClient):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock(
side_effect=httpx.HTTPStatusError(
"Not Found", request=MagicMock(), response=MagicMock()
)
)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"):
ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
mock_response.raise_for_status.assert_called_once()
@patch("google.adk.tools.api_registry.McpToolset", autospec=True)
@patch("httpx.Client", autospec=True)
async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
api_registry = ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
toolset = api_registry.get_toolset("test-mcp-server-1")
MockMcpToolset.assert_called_once_with(
connection_params=StreamableHTTPConnectionParams(
url="https://mcp.server1.com",
headers={"Authorization": "Bearer mock_token"},
),
tool_filter=None,
tool_name_prefix=None,
header_provider=None,
)
self.assertEqual(toolset, MockMcpToolset.return_value)
@patch("google.adk.tools.api_registry.McpToolset", autospec=True)
@patch("httpx.Client", autospec=True)
async def test_get_toolset_with_filter_and_prefix(
self, MockHttpClient, MockMcpToolset
):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
api_registry = ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
tool_filter = ["tool1"]
tool_name_prefix = "prefix_"
toolset = api_registry.get_toolset(
"test-mcp-server-1",
tool_filter=tool_filter,
tool_name_prefix=tool_name_prefix,
)
MockMcpToolset.assert_called_once_with(
connection_params=StreamableHTTPConnectionParams(
url="https://mcp.server1.com",
headers={"Authorization": "Bearer mock_token"},
),
tool_filter=tool_filter,
tool_name_prefix=tool_name_prefix,
header_provider=None,
)
self.assertEqual(toolset, MockMcpToolset.return_value)
@patch("httpx.Client", autospec=True)
async def test_get_toolset_server_not_found(self, MockHttpClient):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
api_registry = ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
with self.assertRaisesRegex(ValueError, "not found in API Registry"):
api_registry.get_toolset("non-existent-server")
@patch("httpx.Client", autospec=True)
async def test_get_toolset_server_no_url(self, MockHttpClient):
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST)
mock_client_instance = MockHttpClient.return_value
mock_client_instance.__enter__.return_value = mock_client_instance
mock_client_instance.get.return_value = mock_response
api_registry = ApiRegistry(
api_registry_project_id=self.project_id, location=self.location
)
with self.assertRaisesRegex(ValueError, "has no URLs"):
api_registry.get_toolset("test-mcp-server-no-url")