diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index f707d6a0..d659bdaa 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -61,6 +61,7 @@ class HttpAuth(BaseModelWithConfig): # Examples: 'basic', 'bearer' scheme: str credentials: HttpCredentials + additional_headers: Optional[Dict[str, str]] = None class OAuth2Auth(BaseModelWithConfig): diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 4fdc8701..6044cac2 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -74,14 +74,18 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): try: if auth_credential.service_account.use_default_credential: - credentials, _ = google.auth.default( + credentials, project_id = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], ) + quota_project_id = ( + getattr(credentials, "quota_project_id", None) or project_id + ) else: config = auth_credential.service_account credentials = service_account.Credentials.from_service_account_info( config.service_account_credential.model_dump(), scopes=config.scopes ) + quota_project_id = None credentials.refresh(Request()) @@ -90,6 +94,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), + additional_headers={ + "x-goog-user-project": quota_project_id, + } + if quota_project_id + else None, ), ) return updated_credential diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 5c27b168..27c6acda 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -320,6 +320,13 @@ class RestApiTool(BaseTool): user_agent = f"google-adk/{adk_version} (tool: {self.name})" header_params["User-Agent"] = user_agent + if ( + self.auth_credential + and self.auth_credential.http + and self.auth_credential.http.additional_headers + ): + header_params.update(self.auth_credential.http.additional_headers) + params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters} # Fill in path, query, header and cookie parameters to the request diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index db929c8e..4d930b39 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -99,14 +99,28 @@ def test_exchange_credential_success( mock_credentials.refresh.assert_called_once() +@pytest.mark.parametrize( + "cred_quota_project_id, adc_project_id, expected_quota_project_id", + [ + ("test_project", "another_project", "test_project"), + (None, "adc_project", "adc_project"), + (None, None, None), + ], +) def test_exchange_credential_use_default_credential_success( - service_account_exchanger, auth_scheme, monkeypatch + service_account_exchanger, + auth_scheme, + monkeypatch, + cred_quota_project_id, + adc_project_id, + expected_quota_project_id, ): """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = cred_quota_project_id mock_google_auth_default = MagicMock( - return_value=(mock_credentials, "test_project") + return_value=(mock_credentials, adc_project_id) ) monkeypatch.setattr(google.auth, "default", mock_google_auth_default) @@ -125,6 +139,13 @@ def test_exchange_credential_use_default_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" + if expected_quota_project_id: + assert ( + result.http.additional_headers["x-goog-user-project"] + == expected_quota_project_id + ) + else: + assert not result.http.additional_headers # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( scopes=["https://www.googleapis.com/auth/cloud-platform"] diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 560813e6..ddf09aeb 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -25,6 +25,10 @@ from fastapi.openapi.models import Operation from fastapi.openapi.models import Parameter as OpenAPIParameter from fastapi.openapi.models import RequestBody from fastapi.openapi.models import Schema as OpenAPISchema +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials from google.adk.sessions.state import State from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential from google.adk.tools.openapi_tool.common.common import ApiParameter @@ -721,6 +725,35 @@ class TestRestApiTool: assert request_params["cookies"]["session_id"] == "cookie_value" + def test_prepare_request_params_quota_project_id( + self, + sample_endpoint, + sample_operation, + sample_auth_scheme, + ): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(), + additional_headers={"x-goog-user-project": "test-project"}, + ), + ) + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_credential=auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [] + kwargs = {} + + request_params = tool._prepare_request_params(params, kwargs) + + assert request_params["headers"]["x-goog-user-project"] == "test-project" + def test_prepare_request_params_multiple_mime_types( self, sample_endpoint, sample_auth_credential, sample_auth_scheme ):