You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Ensure consistent ADC quota project override in ADK
Fix discovery engine search tool, bigquery agent analytics plugin, and application integration tool to correctly handle the ADC quota project override -- the x-goog-user-project should be set based on the ADC quota project, per gcloud auth team's requirements. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 853841124
This commit is contained in:
committed by
Copybara-Service
parent
8afb99a078
commit
226e873b0f
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING
|
||||
import uuid
|
||||
import weakref
|
||||
|
||||
from google.api_core import client_options
|
||||
from google.api_core.exceptions import InternalServerError
|
||||
from google.api_core.exceptions import ServiceUnavailable
|
||||
from google.api_core.exceptions import TooManyRequests
|
||||
@@ -1352,19 +1353,31 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
if _GLOBAL_WRITE_CLIENT is None:
|
||||
|
||||
def get_credentials():
|
||||
creds, _ = google.auth.default(
|
||||
creds, project_id = google.auth.default(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
return creds
|
||||
return creds, project_id
|
||||
|
||||
creds = await loop.run_in_executor(self._executor, get_credentials)
|
||||
creds, project_id = await loop.run_in_executor(
|
||||
self._executor, get_credentials
|
||||
)
|
||||
quota_project_id = (
|
||||
getattr(creds, "quota_project_id", None) or project_id
|
||||
)
|
||||
options = (
|
||||
client_options.ClientOptions(quota_project_id=quota_project_id)
|
||||
if quota_project_id
|
||||
else None
|
||||
)
|
||||
client_info = gapic_client_info.ClientInfo(
|
||||
user_agent=f"google-adk-bq-logger/{__version__}"
|
||||
)
|
||||
# Initialize the async client in the current event loop, not in the
|
||||
# executor.
|
||||
_GLOBAL_WRITE_CLIENT = BigQueryWriteAsyncClient(
|
||||
credentials=creds, client_info=client_info
|
||||
credentials=creds,
|
||||
client_info=client_info,
|
||||
client_options=options,
|
||||
)
|
||||
self.write_client = _GLOBAL_WRITE_CLIENT
|
||||
|
||||
|
||||
@@ -75,6 +75,7 @@ class IntegrationClient:
|
||||
self.actions = actions if actions is not None else []
|
||||
self.service_account_json = service_account_json
|
||||
self.credential_cache = None
|
||||
self._quota_project_id = None
|
||||
|
||||
def get_openapi_spec_for_integration(self):
|
||||
"""Gets the OpenAPI spec for the integration.
|
||||
@@ -92,6 +93,8 @@ class IntegrationClient:
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._get_access_token()}",
|
||||
}
|
||||
if not self.service_account_json:
|
||||
headers["x-goog-user-project"] = self._quota_project_id or self.project
|
||||
data = {
|
||||
"apiTriggerResources": [
|
||||
{
|
||||
@@ -247,11 +250,14 @@ class IntegrationClient:
|
||||
)
|
||||
else:
|
||||
try:
|
||||
credentials, _ = default_service_credential(
|
||||
credentials, project_id = default_service_credential(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
except:
|
||||
credentials = None
|
||||
if credentials:
|
||||
quota_project_id = getattr(credentials, "quota_project_id", None)
|
||||
self._quota_project_id = quota_project_id or project_id
|
||||
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.api_core import client_options
|
||||
from google.api_core.exceptions import GoogleAPICallError
|
||||
import google.auth
|
||||
from google.cloud import discoveryengine_v1beta as discoveryengine
|
||||
@@ -72,8 +73,14 @@ class DiscoveryEngineSearchTool(FunctionTool):
|
||||
self._max_results = max_results
|
||||
|
||||
credentials, _ = google.auth.default()
|
||||
quota_project_id = getattr(credentials, "quota_project_id", None)
|
||||
options = (
|
||||
client_options.ClientOptions(quota_project_id=quota_project_id)
|
||||
if quota_project_id
|
||||
else None
|
||||
)
|
||||
self._discovery_engine_client = discoveryengine.SearchServiceClient(
|
||||
credentials=credentials
|
||||
credentials=credentials, client_options=options
|
||||
)
|
||||
|
||||
def discovery_engine_search(
|
||||
|
||||
@@ -1568,6 +1568,41 @@ class TestBigQueryAgentAnalyticsPlugin:
|
||||
await plugin2.shutdown()
|
||||
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quota_project_id_used_in_client(
|
||||
self,
|
||||
mock_bq_client,
|
||||
mock_to_arrow_schema,
|
||||
mock_asyncio_to_thread,
|
||||
):
|
||||
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
|
||||
mock_creds = mock.create_autospec(
|
||||
google.auth.credentials.Credentials, instance=True, spec_set=True
|
||||
)
|
||||
mock_creds.quota_project_id = "quota-project"
|
||||
with mock.patch.object(
|
||||
google.auth,
|
||||
"default",
|
||||
autospec=True,
|
||||
return_value=(mock_creds, PROJECT_ID),
|
||||
) as mock_auth_default:
|
||||
with mock.patch.object(
|
||||
bigquery_agent_analytics_plugin,
|
||||
"BigQueryWriteAsyncClient",
|
||||
autospec=True,
|
||||
) as mock_bq_write_cls:
|
||||
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
|
||||
project_id=PROJECT_ID,
|
||||
dataset_id=DATASET_ID,
|
||||
table_id=TABLE_ID,
|
||||
)
|
||||
await plugin._ensure_started()
|
||||
mock_auth_default.assert_called_once()
|
||||
mock_bq_write_cls.assert_called_once()
|
||||
_, kwargs = mock_bq_write_cls.call_args
|
||||
assert kwargs["client_options"].quota_project_id == "quota-project"
|
||||
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
|
||||
"""Test that the plugin can be pickled safely."""
|
||||
|
||||
+9
-4
@@ -16,6 +16,7 @@ import json
|
||||
import re
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.application_integration_tool.clients import integration_client
|
||||
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
||||
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
||||
import google.auth
|
||||
@@ -110,6 +111,8 @@ class TestIntegrationClient:
|
||||
mock_credentials,
|
||||
mock_connections_client,
|
||||
):
|
||||
mock_credentials.quota_project_id = "quota-project"
|
||||
mock_credentials.expired = False
|
||||
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@@ -117,11 +120,12 @@ class TestIntegrationClient:
|
||||
|
||||
with (
|
||||
mock.patch.object(
|
||||
IntegrationClient,
|
||||
"_get_access_token",
|
||||
return_value=mock_credentials.token,
|
||||
integration_client,
|
||||
"default_service_credential",
|
||||
return_value=(mock_credentials, project),
|
||||
),
|
||||
mock.patch("requests.post", return_value=mock_response),
|
||||
mock.patch.object(mock_credentials, "refresh", return_value=None),
|
||||
mock.patch.object(requests, "post", return_value=mock_response),
|
||||
):
|
||||
client = IntegrationClient(
|
||||
project=project,
|
||||
@@ -140,6 +144,7 @@ class TestIntegrationClient:
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {mock_credentials.token}",
|
||||
"x-goog-user-project": "quota-project",
|
||||
},
|
||||
json={
|
||||
"apiTriggerResources": [{
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools import discovery_engine_search_tool
|
||||
from google.adk.tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
|
||||
from google.api_core import exceptions
|
||||
from google.cloud import discoveryengine_v1beta as discoveryengine
|
||||
import pytest
|
||||
|
||||
from google import auth
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.auth.default",
|
||||
@@ -76,10 +79,14 @@ class TestDiscoveryEngineSearchTool:
|
||||
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
|
||||
)
|
||||
|
||||
@mock.patch(
|
||||
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
|
||||
@mock.patch.object(discovery_engine_search_tool, "client_options")
|
||||
@mock.patch.object(
|
||||
discoveryengine,
|
||||
"SearchServiceClient",
|
||||
)
|
||||
def test_discovery_engine_search_success(self, mock_search_client):
|
||||
def test_discovery_engine_search_success(
|
||||
self, mock_search_client, mock_client_options
|
||||
):
|
||||
"""Test successful discovery engine search."""
|
||||
mock_response = discoveryengine.SearchResponse()
|
||||
mock_response.results = [
|
||||
@@ -98,15 +105,28 @@ class TestDiscoveryEngineSearchTool:
|
||||
)
|
||||
]
|
||||
mock_search_client.return_value.search.return_value = mock_response
|
||||
mock_credentials = mock.MagicMock()
|
||||
mock_credentials.quota_project_id = "test-quota-project"
|
||||
|
||||
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
|
||||
result = tool.discovery_engine_search("test query")
|
||||
with mock.patch.object(
|
||||
auth, "default", return_value=(mock_credentials, "project")
|
||||
) as mock_auth:
|
||||
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
|
||||
result = tool.discovery_engine_search("test query")
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["title"] == "Test Title"
|
||||
assert result["results"][0]["url"] == "http://example.com"
|
||||
assert result["results"][0]["content"] == "Test Content"
|
||||
assert result["status"] == "success"
|
||||
assert len(result["results"]) == 1
|
||||
assert result["results"][0]["title"] == "Test Title"
|
||||
assert result["results"][0]["url"] == "http://example.com"
|
||||
assert result["results"][0]["content"] == "Test Content"
|
||||
mock_auth.assert_called_once()
|
||||
mock_client_options.ClientOptions.assert_called_once_with(
|
||||
quota_project_id="test-quota-project"
|
||||
)
|
||||
mock_search_client.assert_called_once_with(
|
||||
credentials=mock_credentials,
|
||||
client_options=mock_client_options.ClientOptions.return_value,
|
||||
)
|
||||
|
||||
@mock.patch(
|
||||
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
|
||||
|
||||
Reference in New Issue
Block a user