fix: Ensure consistent ADC quota project override in ADK

Fix discovery engine search tool, bigquery agent analytics plugin, and application integration tool to correctly handle the ADC quota project override -- the x-goog-user-project should be set based on the ADC quota project, per gcloud auth team's requirements.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 853841124
This commit is contained in:
Kathy Wu
2026-01-08 12:22:08 -08:00
committed by Copybara-Service
parent 8afb99a078
commit 226e873b0f
6 changed files with 106 additions and 20 deletions
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING
import uuid
import weakref
from google.api_core import client_options
from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import ServiceUnavailable
from google.api_core.exceptions import TooManyRequests
@@ -1352,19 +1353,31 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
if _GLOBAL_WRITE_CLIENT is None:
def get_credentials():
creds, _ = google.auth.default(
creds, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
return creds
return creds, project_id
creds = await loop.run_in_executor(self._executor, get_credentials)
creds, project_id = await loop.run_in_executor(
self._executor, get_credentials
)
quota_project_id = (
getattr(creds, "quota_project_id", None) or project_id
)
options = (
client_options.ClientOptions(quota_project_id=quota_project_id)
if quota_project_id
else None
)
client_info = gapic_client_info.ClientInfo(
user_agent=f"google-adk-bq-logger/{__version__}"
)
# Initialize the async client in the current event loop, not in the
# executor.
_GLOBAL_WRITE_CLIENT = BigQueryWriteAsyncClient(
credentials=creds, client_info=client_info
credentials=creds,
client_info=client_info,
client_options=options,
)
self.write_client = _GLOBAL_WRITE_CLIENT
@@ -75,6 +75,7 @@ class IntegrationClient:
self.actions = actions if actions is not None else []
self.service_account_json = service_account_json
self.credential_cache = None
self._quota_project_id = None
def get_openapi_spec_for_integration(self):
"""Gets the OpenAPI spec for the integration.
@@ -92,6 +93,8 @@ class IntegrationClient:
"Content-Type": "application/json",
"Authorization": f"Bearer {self._get_access_token()}",
}
if not self.service_account_json:
headers["x-goog-user-project"] = self._quota_project_id or self.project
data = {
"apiTriggerResources": [
{
@@ -247,11 +250,14 @@ class IntegrationClient:
)
else:
try:
credentials, _ = default_service_credential(
credentials, project_id = default_service_credential(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
except:
credentials = None
if credentials:
quota_project_id = getattr(credentials, "quota_project_id", None)
self._quota_project_id = quota_project_id or project_id
if not credentials:
raise ValueError(
@@ -17,6 +17,7 @@ from __future__ import annotations
from typing import Any
from typing import Optional
from google.api_core import client_options
from google.api_core.exceptions import GoogleAPICallError
import google.auth
from google.cloud import discoveryengine_v1beta as discoveryengine
@@ -72,8 +73,14 @@ class DiscoveryEngineSearchTool(FunctionTool):
self._max_results = max_results
credentials, _ = google.auth.default()
quota_project_id = getattr(credentials, "quota_project_id", None)
options = (
client_options.ClientOptions(quota_project_id=quota_project_id)
if quota_project_id
else None
)
self._discovery_engine_client = discoveryengine.SearchServiceClient(
credentials=credentials
credentials=credentials, client_options=options
)
def discovery_engine_search(
@@ -1568,6 +1568,41 @@ class TestBigQueryAgentAnalyticsPlugin:
await plugin2.shutdown()
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
@pytest.mark.asyncio
async def test_quota_project_id_used_in_client(
self,
mock_bq_client,
mock_to_arrow_schema,
mock_asyncio_to_thread,
):
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
mock_creds = mock.create_autospec(
google.auth.credentials.Credentials, instance=True, spec_set=True
)
mock_creds.quota_project_id = "quota-project"
with mock.patch.object(
google.auth,
"default",
autospec=True,
return_value=(mock_creds, PROJECT_ID),
) as mock_auth_default:
with mock.patch.object(
bigquery_agent_analytics_plugin,
"BigQueryWriteAsyncClient",
autospec=True,
) as mock_bq_write_cls:
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
project_id=PROJECT_ID,
dataset_id=DATASET_ID,
table_id=TABLE_ID,
)
await plugin._ensure_started()
mock_auth_default.assert_called_once()
mock_bq_write_cls.assert_called_once()
_, kwargs = mock_bq_write_cls.call_args
assert kwargs["client_options"].quota_project_id == "quota-project"
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
@pytest.mark.asyncio
async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
"""Test that the plugin can be pickled safely."""
@@ -16,6 +16,7 @@ import json
import re
from unittest import mock
from google.adk.tools.application_integration_tool.clients import integration_client
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
import google.auth
@@ -110,6 +111,8 @@ class TestIntegrationClient:
mock_credentials,
mock_connections_client,
):
mock_credentials.quota_project_id = "quota-project"
mock_credentials.expired = False
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
mock_response = mock.MagicMock()
mock_response.status_code = 200
@@ -117,11 +120,12 @@ class TestIntegrationClient:
with (
mock.patch.object(
IntegrationClient,
"_get_access_token",
return_value=mock_credentials.token,
integration_client,
"default_service_credential",
return_value=(mock_credentials, project),
),
mock.patch("requests.post", return_value=mock_response),
mock.patch.object(mock_credentials, "refresh", return_value=None),
mock.patch.object(requests, "post", return_value=mock_response),
):
client = IntegrationClient(
project=project,
@@ -140,6 +144,7 @@ class TestIntegrationClient:
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {mock_credentials.token}",
"x-goog-user-project": "quota-project",
},
json={
"apiTriggerResources": [{
@@ -14,11 +14,14 @@
from unittest import mock
from google.adk.tools import discovery_engine_search_tool
from google.adk.tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
from google.api_core import exceptions
from google.cloud import discoveryengine_v1beta as discoveryengine
import pytest
from google import auth
@mock.patch(
"google.auth.default",
@@ -76,10 +79,14 @@ class TestDiscoveryEngineSearchTool:
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
)
@mock.patch(
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
@mock.patch.object(discovery_engine_search_tool, "client_options")
@mock.patch.object(
discoveryengine,
"SearchServiceClient",
)
def test_discovery_engine_search_success(self, mock_search_client):
def test_discovery_engine_search_success(
self, mock_search_client, mock_client_options
):
"""Test successful discovery engine search."""
mock_response = discoveryengine.SearchResponse()
mock_response.results = [
@@ -98,15 +105,28 @@ class TestDiscoveryEngineSearchTool:
)
]
mock_search_client.return_value.search.return_value = mock_response
mock_credentials = mock.MagicMock()
mock_credentials.quota_project_id = "test-quota-project"
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
result = tool.discovery_engine_search("test query")
with mock.patch.object(
auth, "default", return_value=(mock_credentials, "project")
) as mock_auth:
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
result = tool.discovery_engine_search("test query")
assert result["status"] == "success"
assert len(result["results"]) == 1
assert result["results"][0]["title"] == "Test Title"
assert result["results"][0]["url"] == "http://example.com"
assert result["results"][0]["content"] == "Test Content"
assert result["status"] == "success"
assert len(result["results"]) == 1
assert result["results"][0]["title"] == "Test Title"
assert result["results"][0]["url"] == "http://example.com"
assert result["results"][0]["content"] == "Test Content"
mock_auth.assert_called_once()
mock_client_options.ClientOptions.assert_called_once_with(
quota_project_id="test-quota-project"
)
mock_search_client.assert_called_once_with(
credentials=mock_credentials,
client_options=mock_client_options.ClientOptions.return_value,
)
@mock.patch(
"google.cloud.discoveryengine_v1beta.SearchServiceClient",