From 226e873b0f0aa983100d9abeaf8ca43675fb373f Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 8 Jan 2026 12:22:08 -0800 Subject: [PATCH] fix: Ensure consistent ADC quota project override in ADK Fix discovery engine search tool, bigquery agent analytics plugin, and application integration tool to correctly handle the ADC quota project override -- the x-goog-user-project should be set based on the ADC quota project, per gcloud auth team's requirements. Co-authored-by: Kathy Wu PiperOrigin-RevId: 853841124 --- .../bigquery_agent_analytics_plugin.py | 21 ++++++++-- .../clients/integration_client.py | 8 +++- .../adk/tools/discovery_engine_search_tool.py | 9 ++++- .../test_bigquery_agent_analytics_plugin.py | 35 ++++++++++++++++ .../clients/test_integration_client.py | 13 ++++-- .../test_discovery_engine_search_tool.py | 40 ++++++++++++++----- 6 files changed, 106 insertions(+), 20 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 8e0f646d..a80bd0f1 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -36,6 +36,7 @@ from typing import TYPE_CHECKING import uuid import weakref +from google.api_core import client_options from google.api_core.exceptions import InternalServerError from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import TooManyRequests @@ -1352,19 +1353,31 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): if _GLOBAL_WRITE_CLIENT is None: def get_credentials(): - creds, _ = google.auth.default( + creds, project_id = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) - return creds + return creds, project_id - creds = await loop.run_in_executor(self._executor, get_credentials) + creds, project_id = await loop.run_in_executor( + self._executor, get_credentials + ) + quota_project_id = ( + getattr(creds, "quota_project_id", None) or project_id + ) + options = ( + client_options.ClientOptions(quota_project_id=quota_project_id) + if quota_project_id + else None + ) client_info = gapic_client_info.ClientInfo( user_agent=f"google-adk-bq-logger/{__version__}" ) # Initialize the async client in the current event loop, not in the # executor. _GLOBAL_WRITE_CLIENT = BigQueryWriteAsyncClient( - credentials=creds, client_info=client_info + credentials=creds, + client_info=client_info, + client_options=options, ) self.write_client = _GLOBAL_WRITE_CLIENT diff --git a/src/google/adk/tools/application_integration_tool/clients/integration_client.py b/src/google/adk/tools/application_integration_tool/clients/integration_client.py index 0d8789d5..29a73357 100644 --- a/src/google/adk/tools/application_integration_tool/clients/integration_client.py +++ b/src/google/adk/tools/application_integration_tool/clients/integration_client.py @@ -75,6 +75,7 @@ class IntegrationClient: self.actions = actions if actions is not None else [] self.service_account_json = service_account_json self.credential_cache = None + self._quota_project_id = None def get_openapi_spec_for_integration(self): """Gets the OpenAPI spec for the integration. @@ -92,6 +93,8 @@ class IntegrationClient: "Content-Type": "application/json", "Authorization": f"Bearer {self._get_access_token()}", } + if not self.service_account_json: + headers["x-goog-user-project"] = self._quota_project_id or self.project data = { "apiTriggerResources": [ { @@ -247,11 +250,14 @@ class IntegrationClient: ) else: try: - credentials, _ = default_service_credential( + credentials, project_id = default_service_credential( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) except: credentials = None + if credentials: + quota_project_id = getattr(credentials, "quota_project_id", None) + self._quota_project_id = quota_project_id or project_id if not credentials: raise ValueError( diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index 0e771ece..74e4dd9d 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -17,6 +17,7 @@ from __future__ import annotations from typing import Any from typing import Optional +from google.api_core import client_options from google.api_core.exceptions import GoogleAPICallError import google.auth from google.cloud import discoveryengine_v1beta as discoveryengine @@ -72,8 +73,14 @@ class DiscoveryEngineSearchTool(FunctionTool): self._max_results = max_results credentials, _ = google.auth.default() + quota_project_id = getattr(credentials, "quota_project_id", None) + options = ( + client_options.ClientOptions(quota_project_id=quota_project_id) + if quota_project_id + else None + ) self._discovery_engine_client = discoveryengine.SearchServiceClient( - credentials=credentials + credentials=credentials, client_options=options ) def discovery_engine_search( diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 4c089e3a..a5438cd2 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -1568,6 +1568,41 @@ class TestBigQueryAgentAnalyticsPlugin: await plugin2.shutdown() bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None + @pytest.mark.asyncio + async def test_quota_project_id_used_in_client( + self, + mock_bq_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None + mock_creds = mock.create_autospec( + google.auth.credentials.Credentials, instance=True, spec_set=True + ) + mock_creds.quota_project_id = "quota-project" + with mock.patch.object( + google.auth, + "default", + autospec=True, + return_value=(mock_creds, PROJECT_ID), + ) as mock_auth_default: + with mock.patch.object( + bigquery_agent_analytics_plugin, + "BigQueryWriteAsyncClient", + autospec=True, + ) as mock_bq_write_cls: + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await plugin._ensure_started() + mock_auth_default.assert_called_once() + mock_bq_write_cls.assert_called_once() + _, kwargs = mock_bq_write_cls.call_args + assert kwargs["client_options"].quota_project_id == "quota-project" + bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None + @pytest.mark.asyncio async def test_pickle_safety(self, mock_auth_default, mock_bq_client): """Test that the plugin can be pickled safely.""" diff --git a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py index 7b07442d..c5009e70 100644 --- a/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py +++ b/tests/unittests/tools/application_integration_tool/clients/test_integration_client.py @@ -16,6 +16,7 @@ import json import re from unittest import mock +from google.adk.tools.application_integration_tool.clients import integration_client from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient import google.auth @@ -110,6 +111,8 @@ class TestIntegrationClient: mock_credentials, mock_connections_client, ): + mock_credentials.quota_project_id = "quota-project" + mock_credentials.expired = False expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}} mock_response = mock.MagicMock() mock_response.status_code = 200 @@ -117,11 +120,12 @@ class TestIntegrationClient: with ( mock.patch.object( - IntegrationClient, - "_get_access_token", - return_value=mock_credentials.token, + integration_client, + "default_service_credential", + return_value=(mock_credentials, project), ), - mock.patch("requests.post", return_value=mock_response), + mock.patch.object(mock_credentials, "refresh", return_value=None), + mock.patch.object(requests, "post", return_value=mock_response), ): client = IntegrationClient( project=project, @@ -140,6 +144,7 @@ class TestIntegrationClient: headers={ "Content-Type": "application/json", "Authorization": f"Bearer {mock_credentials.token}", + "x-goog-user-project": "quota-project", }, json={ "apiTriggerResources": [{ diff --git a/tests/unittests/tools/test_discovery_engine_search_tool.py b/tests/unittests/tools/test_discovery_engine_search_tool.py index d10da252..c3525451 100644 --- a/tests/unittests/tools/test_discovery_engine_search_tool.py +++ b/tests/unittests/tools/test_discovery_engine_search_tool.py @@ -14,11 +14,14 @@ from unittest import mock +from google.adk.tools import discovery_engine_search_tool from google.adk.tools.discovery_engine_search_tool import DiscoveryEngineSearchTool from google.api_core import exceptions from google.cloud import discoveryengine_v1beta as discoveryengine import pytest +from google import auth + @mock.patch( "google.auth.default", @@ -76,10 +79,14 @@ class TestDiscoveryEngineSearchTool: data_store_id="test_data_store", data_store_specs=[{"id": "123"}] ) - @mock.patch( - "google.cloud.discoveryengine_v1beta.SearchServiceClient", + @mock.patch.object(discovery_engine_search_tool, "client_options") + @mock.patch.object( + discoveryengine, + "SearchServiceClient", ) - def test_discovery_engine_search_success(self, mock_search_client): + def test_discovery_engine_search_success( + self, mock_search_client, mock_client_options + ): """Test successful discovery engine search.""" mock_response = discoveryengine.SearchResponse() mock_response.results = [ @@ -98,15 +105,28 @@ class TestDiscoveryEngineSearchTool: ) ] mock_search_client.return_value.search.return_value = mock_response + mock_credentials = mock.MagicMock() + mock_credentials.quota_project_id = "test-quota-project" - tool = DiscoveryEngineSearchTool(data_store_id="test_data_store") - result = tool.discovery_engine_search("test query") + with mock.patch.object( + auth, "default", return_value=(mock_credentials, "project") + ) as mock_auth: + tool = DiscoveryEngineSearchTool(data_store_id="test_data_store") + result = tool.discovery_engine_search("test query") - assert result["status"] == "success" - assert len(result["results"]) == 1 - assert result["results"][0]["title"] == "Test Title" - assert result["results"][0]["url"] == "http://example.com" - assert result["results"][0]["content"] == "Test Content" + assert result["status"] == "success" + assert len(result["results"]) == 1 + assert result["results"][0]["title"] == "Test Title" + assert result["results"][0]["url"] == "http://example.com" + assert result["results"][0]["content"] == "Test Content" + mock_auth.assert_called_once() + mock_client_options.ClientOptions.assert_called_once_with( + quota_project_id="test-quota-project" + ) + mock_search_client.assert_called_once_with( + credentials=mock_credentials, + client_options=mock_client_options.ClientOptions.return_value, + ) @mock.patch( "google.cloud.discoveryengine_v1beta.SearchServiceClient",