fix: Ensure open api tool service account exchanger uses quota project id for ADC

gcloud auth team requested that we audit ADK's codebase for places where ADC (google.auth.default) is used, and make sure that the quota project id header is being populated.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 855322964
This commit is contained in:
Kathy Wu
2026-01-12 11:51:31 -08:00
committed by Copybara-Service
parent 5880109ab1
commit 7c8bc69dd0
5 changed files with 74 additions and 3 deletions
+1
View File
@@ -61,6 +61,7 @@ class HttpAuth(BaseModelWithConfig):
# Examples: 'basic', 'bearer'
scheme: str
credentials: HttpCredentials
additional_headers: Optional[Dict[str, str]] = None
class OAuth2Auth(BaseModelWithConfig):
@@ -74,14 +74,18 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
try:
if auth_credential.service_account.use_default_credential:
credentials, _ = google.auth.default(
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
else:
config = auth_credential.service_account
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
)
quota_project_id = None
credentials.refresh(Request())
@@ -90,6 +94,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
additional_headers={
"x-goog-user-project": quota_project_id,
}
if quota_project_id
else None,
),
)
return updated_credential
@@ -320,6 +320,13 @@ class RestApiTool(BaseTool):
user_agent = f"google-adk/{adk_version} (tool: {self.name})"
header_params["User-Agent"] = user_agent
if (
self.auth_credential
and self.auth_credential.http
and self.auth_credential.http.additional_headers
):
header_params.update(self.auth_credential.http.additional_headers)
params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters}
# Fill in path, query, header and cookie parameters to the request
@@ -99,14 +99,28 @@ def test_exchange_credential_success(
mock_credentials.refresh.assert_called_once()
@pytest.mark.parametrize(
"cred_quota_project_id, adc_project_id, expected_quota_project_id",
[
("test_project", "another_project", "test_project"),
(None, "adc_project", "adc_project"),
(None, None, None),
],
)
def test_exchange_credential_use_default_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
service_account_exchanger,
auth_scheme,
monkeypatch,
cred_quota_project_id,
adc_project_id,
expected_quota_project_id,
):
"""Test successful exchange of service account credentials using default credential."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_credentials.quota_project_id = cred_quota_project_id
mock_google_auth_default = MagicMock(
return_value=(mock_credentials, "test_project")
return_value=(mock_credentials, adc_project_id)
)
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
@@ -125,6 +139,13 @@ def test_exchange_credential_use_default_credential_success(
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
if expected_quota_project_id:
assert (
result.http.additional_headers["x-goog-user-project"]
== expected_quota_project_id
)
else:
assert not result.http.additional_headers
# Verify google.auth.default is called with the correct scopes parameter
mock_google_auth_default.assert_called_once_with(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
@@ -25,6 +25,10 @@ from fastapi.openapi.models import Operation
from fastapi.openapi.models import Parameter as OpenAPIParameter
from fastapi.openapi.models import RequestBody
from fastapi.openapi.models import Schema as OpenAPISchema
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.sessions.state import State
from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential
from google.adk.tools.openapi_tool.common.common import ApiParameter
@@ -721,6 +725,35 @@ class TestRestApiTool:
assert request_params["cookies"]["session_id"] == "cookie_value"
def test_prepare_request_params_quota_project_id(
self,
sample_endpoint,
sample_operation,
sample_auth_scheme,
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(),
additional_headers={"x-goog-user-project": "test-project"},
),
)
tool = RestApiTool(
name="test_tool",
description="Test Tool",
endpoint=sample_endpoint,
operation=sample_operation,
auth_credential=auth_credential,
auth_scheme=sample_auth_scheme,
)
params = []
kwargs = {}
request_params = tool._prepare_request_params(params, kwargs)
assert request_params["headers"]["x-goog-user-project"] == "test-project"
def test_prepare_request_params_multiple_mime_types(
self, sample_endpoint, sample_auth_credential, sample_auth_scheme
):