fix: nit fixes for API Registry - updating imports / type hints and formatting

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 845980315
This commit is contained in:
Kathy Wu
2025-12-17 17:23:26 -08:00
committed by Copybara-Service
parent 0088b0f3ad
commit 253ac87c79
+12 -15
View File
@@ -17,11 +17,8 @@ from __future__ import annotations
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
from google.adk.agents.readonly_context import ReadonlyContext
import google.auth
import google.auth.transport.requests
import httpx
@@ -40,9 +37,9 @@ class ApiRegistry:
self,
api_registry_project_id: str,
location: str = "global",
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
header_provider: (
Callable[[ReadonlyContext], dict[str, str]] | None
) = None,
):
"""Initialize the API Registry.
@@ -55,7 +52,7 @@ class ApiRegistry:
self.api_registry_project_id = api_registry_project_id
self.location = location
self._credentials, _ = google.auth.default()
self._mcp_servers: Dict[str, Dict[str, Any]] = {}
self._mcp_servers: dict[str, dict[str, Any]] = {}
self._header_provider = header_provider
url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers"
@@ -79,14 +76,13 @@ class ApiRegistry:
def get_toolset(
self,
mcp_server_name: str,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
tool_name_prefix: Optional[str] = None,
tool_filter: ToolPredicate | list[str] | None = None,
tool_name_prefix: str | None = None,
) -> McpToolset:
"""Return the MCP Toolset based on the params.
Args:
mcp_server_name: Filter to select the MCP server name to get tools
from.
mcp_server_name: Filter to select the MCP server name to get tools from.
tool_filter: Optional filter to select specific tools. Can be a list of
tool names or a ToolPredicate function.
tool_name_prefix: Optional prefix to prepend to the names of the tools
@@ -116,7 +112,7 @@ class ApiRegistry:
header_provider=self._header_provider,
)
def _get_auth_headers(self) -> Dict[str, str]:
def _get_auth_headers(self) -> dict[str, str]:
"""Refreshes credentials and returns authorization headers."""
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
@@ -124,6 +120,7 @@ class ApiRegistry:
"Authorization": f"Bearer {self._credentials.token}",
}
# Add quota project header if available in ADC
if self._credentials.quota_project_id:
headers["x-goog-user-project"] = self._credentials.quota_project_id
quota_project_id = getattr(self._credentials, "quota_project_id", None)
if quota_project_id:
headers["x-goog-user-project"] = quota_project_id
return headers