diff --git a/contributing/samples/api_registry_agent/README.md b/contributing/samples/api_registry_agent/README.md new file mode 100644 index 00000000..78b3c223 --- /dev/null +++ b/contributing/samples/api_registry_agent/README.md @@ -0,0 +1,21 @@ +# BigQuery API Registry Agent + +This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry. + +## Prerequisites + +- A Google Cloud project with the API Registry API enabled. +- An MCP server exposing BigQuery tools registered in API Registry. + +## Configuration & Running + +1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server. +2. **Run in CLI:** + ```bash + adk run contributing/samples/api_registry_agent -- --log-level DEBUG + ``` +3. **Run in Web UI:** + ```bash + adk web contributing/samples/ + ``` + Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent. diff --git a/contributing/samples/api_registry_agent/__init__.py b/contributing/samples/api_registry_agent/__init__.py new file mode 100644 index 00000000..c48963cd --- /dev/null +++ b/contributing/samples/api_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py new file mode 100644 index 00000000..65048220 --- /dev/null +++ b/contributing/samples/api_registry_agent/agent.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.api_registry import ApiRegistry + +# TODO: Fill in with your GCloud project id and MCP server name +PROJECT_ID = "your-google-cloud-project-id" +MCP_SERVER_NAME = "your-mcp-server-name" + +# Header required for BigQuery MCP server +header_provider = lambda context: { + "x-goog-user-project": PROJECT_ID, +} +api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider) +registry_tools = api_registry.get_toolset( + mcp_server_name=MCP_SERVER_NAME, +) +root_agent = LlmAgent( + model="gemini-2.0-flash", + name="bigquery_assistant", + instruction=""" +Help user access their BigQuery data via API Registry tools. + """, + tools=[registry_tools], +) diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index d359abb7..32264adc 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from ..auth.auth_tool import AuthToolArguments from .agent_tool import AgentTool + from .api_registry import ApiRegistry from .apihub_tool.apihub_toolset import APIHubToolset from .base_tool import BaseTool from .discovery_engine_search_tool import DiscoveryEngineSearchTool @@ -84,6 +85,7 @@ _LAZY_MAPPING = { 'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'), 'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'), 'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'), + 'ApiRegistry': ('.api_registry', 'ApiRegistry'), } __all__ = list(_LAZY_MAPPING.keys()) diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py new file mode 100644 index 00000000..941c6f0d --- /dev/null +++ b/src/google/adk/tools/api_registry.py @@ -0,0 +1,124 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import google.auth +import google.auth.transport.requests +import httpx + +from .base_toolset import ToolPredicate +from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from .mcp_tool.mcp_toolset import McpToolset + +# TODO(wukathy): Update to prod URL once it is available. +API_REGISTRY_URL = "https://staging-cloudapiregistry.sandbox.googleapis.com" + + +class ApiRegistry: + """Registry that provides McpToolsets for MCP servers registered in API Registry.""" + + def __init__( + self, + api_registry_project_id: str, + location: str = "global", + header_provider: Optional[ + Callable[[ReadonlyContext], Dict[str, str]] + ] = None, + ): + """Initialize the API Registry. + + Args: + api_registry_project_id: The project ID for the Google Cloud API Registry. + location: The location of the API Registry resources. + header_provider: Optional function to provide additional headers for MCP + server calls. + """ + self.api_registry_project_id = api_registry_project_id + self.location = location + self._credentials, _ = google.auth.default() + self._mcp_servers: Dict[str, Dict[str, Any]] = {} + self._header_provider = header_provider + + url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + "Content-Type": "application/json", + } + with httpx.Client() as client: + response = client.get(url, headers=headers) + response.raise_for_status() + mcp_servers_list = response.json().get("mcpServers", []) + for server in mcp_servers_list: + server_name = server.get("name", "") + if server_name: + self._mcp_servers[server_name] = server + except (httpx.HTTPError, ValueError) as e: + # Handle error in fetching or parsing tool definitions + raise RuntimeError( + f"Error fetching MCP servers from API Registry: {e}" + ) from e + + def get_toolset( + self, + mcp_server_name: str, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + tool_name_prefix: Optional[str] = None, + ) -> McpToolset: + """Return the MCP Toolset based on the params. + + Args: + mcp_server_name: Filter to select the MCP server name to get tools + from. + tool_filter: Optional filter to select specific tools. Can be a list of + tool names or a ToolPredicate function. + tool_name_prefix: Optional prefix to prepend to the names of the tools + returned by the toolset. + + Returns: + McpToolset: A toolset for the MCP server specified. + """ + server = self._mcp_servers.get(mcp_server_name) + if not server: + raise ValueError( + f"MCP server {mcp_server_name} not found in API Registry." + ) + if not server.get("urls"): + raise ValueError(f"MCP server {mcp_server_name} has no URLs.") + + mcp_server_url = server["urls"][0] + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + return McpToolset( + connection_params=StreamableHTTPConnectionParams( + url="https://" + mcp_server_url, + headers=headers, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=self._header_provider, + ) diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/tools/test_api_registry.py new file mode 100644 index 00000000..d1131eed --- /dev/null +++ b/tests/unittests/tools/test_api_registry.py @@ -0,0 +1,205 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.tools.api_registry import ApiRegistry +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import httpx + +MOCK_MCP_SERVERS_LIST = { + "mcpServers": [ + { + "name": "test-mcp-server-1", + "urls": ["mcp.server1.com"], + }, + { + "name": "test-mcp-server-2", + "urls": ["mcp.server2.com"], + }, + { + "name": "test-mcp-server-no-url", + }, + ] +} + + +class TestApiRegistry(unittest.IsolatedAsyncioTestCase): + """Unit tests for ApiRegistry.""" + + def setUp(self): + self.project_id = "test-project" + self.location = "global" + self.mock_credentials = MagicMock() + self.mock_credentials.token = "mock_token" + self.mock_credentials.refresh = MagicMock() + mock_auth_patcher = patch( + "google.auth.default", + return_value=(self.mock_credentials, None), + autospec=True, + ) + mock_auth_patcher.start() + self.addCleanup(mock_auth_patcher.stop) + + @patch("httpx.Client", autospec=True) + def test_init_success(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + self.assertEqual(len(api_registry._mcp_servers), 3) + self.assertIn("test-mcp-server-1", api_registry._mcp_servers) + self.assertIn("test-mcp-server-2", api_registry._mcp_servers) + self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) + mock_client_instance.get.assert_called_once_with( + f"https://staging-cloudapiregistry.sandbox.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={ + "Authorization": "Bearer mock_token", + "Content-Type": "application/json", + }, + ) + + @patch("httpx.Client", autospec=True) + def test_init_http_error(self, MockHttpClient): + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.side_effect = httpx.RequestError( + "Connection failed" + ) + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + @patch("httpx.Client", autospec=True) + def test_init_bad_response(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=MagicMock() + ) + ) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + mock_response.raise_for_status.assert_called_once() + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + toolset = api_registry.get_toolset("test-mcp-server-1") + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=None, + tool_name_prefix=None, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_with_filter_and_prefix( + self, MockHttpClient, MockMcpToolset + ): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + tool_filter = ["tool1"] + tool_name_prefix = "prefix_" + toolset = api_registry.get_toolset( + "test-mcp-server-1", + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + ) + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_not_found(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "not found in API Registry"): + api_registry.get_toolset("non-existent-server") + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_no_url(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "has no URLs"): + api_registry.get_toolset("test-mcp-server-no-url")