diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index e205d9be..6160edcc 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -25,6 +25,7 @@ from pydantic import alias_generators from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator class BaseModelWithConfig(BaseModel): @@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig): class ServiceAccount(BaseModelWithConfig): - """Represents Google Service Account configuration.""" + """Represents Google Service Account configuration. + + Attributes: + service_account_credential: The service account credential (JSON key). + scopes: The OAuth2 scopes to request. Optional; when omitted with + ``use_default_credential=True``, defaults to the cloud-platform scope. + use_default_credential: Whether to use Application Default Credentials. + use_id_token: Whether to exchange for an ID token instead of an access + token. Required for service-to-service authentication with Cloud Run, + Cloud Functions, and other Google Cloud services that require identity + verification. When True, ``audience`` must also be set. + audience: The target audience for the ID token, typically the URL of the + receiving service (e.g. ``https://my-service-xyz.run.app``). Required + when ``use_id_token`` is True. + """ service_account_credential: Optional[ServiceAccountCredential] = None - scopes: List[str] + scopes: Optional[List[str]] = None use_default_credential: Optional[bool] = False + use_id_token: Optional[bool] = False + audience: Optional[str] = None + + @model_validator(mode="after") + def _validate_config(self) -> ServiceAccount: + if ( + not self.use_default_credential + and self.service_account_credential is None + ): + raise ValueError( + "service_account_credential is required when" + " use_default_credential is False." + ) + if self.use_id_token and not self.audience: + raise ValueError( + "audience is required when use_id_token is True. Set it to the" + " URL of the target service" + " (e.g. 'https://my-service.run.app')." + ) + return self class AuthCredentialTypes(str, Enum): diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 1dbe0fe4..2b79edf9 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -19,6 +19,7 @@ from __future__ import annotations from typing import Optional import google.auth +from google.auth import exceptions as google_auth_exceptions from google.auth.transport.requests import Request from google.oauth2 import service_account import google.oauth2.credentials @@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredential from .....auth.auth_credential import AuthCredentialTypes from .....auth.auth_credential import HttpAuth from .....auth.auth_credential import HttpCredentials +from .....auth.auth_credential import ServiceAccount from .....auth.auth_schemes import AuthScheme from .base_credential_exchanger import AuthCredentialMissingError from .base_credential_exchanger import BaseAuthCredentialExchanger @@ -38,6 +40,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): Uses the default service credential if `use_default_credential = True`. Otherwise, uses the service account credential provided in the auth credential. + + Supports exchanging for either an access token (default) or an ID token + when ``ServiceAccount.use_id_token`` is True. ID tokens are required for + service-to-service authentication with Cloud Run, Cloud Functions, and + other services that verify caller identity. """ def exchange_credential( @@ -45,52 +52,130 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): auth_scheme: AuthScheme, auth_credential: Optional[AuthCredential] = None, ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. + """Exchanges the service account auth credential for a token. If auth_credential contains a service account credential, it will be used - to fetch an access token. Otherwise, the default service credential will be - used for fetching an access token. + to fetch a token. Otherwise, the default service credential will be + used for fetching a token. + + When ``service_account.use_id_token`` is True, an ID token is fetched + using the configured ``audience``. This is required for authenticating + to Cloud Run, Cloud Functions, and similar services. Args: auth_scheme: The auth scheme. auth_credential: The auth credential. Returns: - An AuthCredential in HTTPBearer format, containing the access token. + An AuthCredential in HTTPBearer format, containing the token. """ - if ( - auth_credential is None - or auth_credential.service_account is None - or ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ) - ): + if auth_credential is None or auth_credential.service_account is None: raise AuthCredentialMissingError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" + "Service account credentials are missing. Please provide them, or" + " set `use_default_credential = True` to use application default" " credential in a hosted service like Cloud Run." ) + sa_config = auth_credential.service_account + + if sa_config.use_id_token: + return self._exchange_for_id_token(sa_config) + + return self._exchange_for_access_token(sa_config) + + def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential: + """Exchanges the service account credential for an ID token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the ID token. + + Raises: + AuthCredentialMissingError: If token exchange fails. + """ + # audience and credential presence are validated by the ServiceAccount + # model_validator at construction time. try: - if auth_credential.service_account.use_default_credential: - credentials, project_id = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - quota_project_id = ( - getattr(credentials, "quota_project_id", None) or project_id - ) + if sa_config.use_default_credential: + from google.oauth2 import id_token as oauth2_id_token + + request = Request() + token = oauth2_id_token.fetch_id_token(request, sa_config.audience) else: - config = auth_credential.service_account + # Guaranteed non-None by ServiceAccount model_validator. + assert sa_config.service_account_credential is not None + credentials = ( + service_account.IDTokenCredentials.from_service_account_info( + sa_config.service_account_credential.model_dump(), + target_audience=sa_config.audience, + ) + ) + credentials.refresh(Request()) + token = credentials.token + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=token), + ), + ) + + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key), or when + # fetch_id_token cannot determine credentials from the environment. + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: + raise AuthCredentialMissingError( + f"Failed to exchange service account for ID token: {e}" + ) from e + + def _exchange_for_access_token( + self, sa_config: ServiceAccount + ) -> AuthCredential: + """Exchanges the service account credential for an access token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the access token. + + Raises: + AuthCredentialMissingError: If scopes are missing for explicit + credentials or token exchange fails. + """ + if not sa_config.use_default_credential and not sa_config.scopes: + raise AuthCredentialMissingError( + "scopes are required when using explicit service account credentials" + " for access token exchange." + ) + + try: + if sa_config.use_default_credential: + scopes = ( + sa_config.scopes + if sa_config.scopes + else ["https://www.googleapis.com/auth/cloud-platform"] + ) + credentials, project_id = google.auth.default( + scopes=scopes, + ) + quota_project_id = credentials.quota_project_id or project_id + else: + # Guaranteed non-None by ServiceAccount model_validator. + assert sa_config.service_account_credential is not None credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes + sa_config.service_account_credential.model_dump(), + scopes=sa_config.scopes, ) quota_project_id = None credentials.refresh(Request()) - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), @@ -101,9 +186,10 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): else None, ), ) - return updated_credential - except Exception as e: + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key). + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: raise AuthCredentialMissingError( f"Failed to exchange service account token: {e}" ) from e diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index c4c85e77..f38a8bbc 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -534,7 +534,9 @@ class TestMCPTool: ) # Create service account credential - service_account = ServiceAccount(scopes=["test"]) + service_account = ServiceAccount( + scopes=["test"], use_default_credential=True + ) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=service_account, diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index 0ca99444..fb35daf6 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -25,8 +25,23 @@ from google.adk.auth.auth_schemes import AuthSchemeType from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import google.auth +from google.auth import exceptions as google_auth_exceptions import pytest +_ACCESS_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.Credentials." + "from_service_account_info" +) + +_ID_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.IDTokenCredentials." + "from_service_account_info" +) + +_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token" + @pytest.fixture def service_account_exchanger(): @@ -41,50 +56,45 @@ def auth_scheme(): return scheme -def test_exchange_credential_success( - service_account_exchanger, auth_scheme, monkeypatch +@pytest.fixture +def sa_credential(): + """A minimal valid ServiceAccountCredential for testing.""" + return ServiceAccountCredential( + type_="service_account", + project_id="test_project_id", + private_key_id="test_private_key_id", + private_key="-----BEGIN PRIVATE KEY-----...", + client_email="test@test.iam.gserviceaccount.com", + client_id="test_client_id", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs", + client_x509_cert_url=( + "https://www.googleapis.com/robot/v1/metadata/x509/test" + ), + universe_domain="googleapis.com", + ) + + +_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + + +# --- Access token exchange tests --- + + +def test_exchange_access_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test successful exchange of service account credentials.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_from_sa_info = MagicMock(return_value=mock_credentials) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) - # Mock the from_service_account_info method - mock_from_service_account_info = MagicMock(return_value=mock_credentials) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, - ) - - # Mock the refresh method - mock_credentials.refresh = MagicMock() - - # Create a valid AuthCredential with service account info auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." - ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) @@ -95,7 +105,7 @@ def test_exchange_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() mock_credentials.refresh.assert_called_once() @@ -107,7 +117,7 @@ def test_exchange_credential_success( (None, None, None), ], ) -def test_exchange_credential_use_default_credential_success( +def test_exchange_access_token_with_adc_sets_quota_project( service_account_exchanger, auth_scheme, monkeypatch, @@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success( adc_project_id, expected_quota_project_id, ): - """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" mock_credentials.quota_project_id = cred_quota_project_id @@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], + scopes=["https://www.googleapis.com/auth/bigquery"], ), ) @@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success( ) else: assert not result.http.additional_headers - # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/bigquery"] ) mock_credentials.refresh.assert_called_once() -def test_exchange_credential_missing_auth_credential( +def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = None + mock_google_auth_default = MagicMock(return_value=(mock_credentials, None)) + monkeypatch.setattr(google.auth, "default", mock_google_auth_default) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES) + + +def test_exchange_raises_when_auth_credential_is_none( service_account_exchanger, auth_scheme ): - """Test missing auth credential during exchange.""" with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, None) assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_missing_service_account_info( +def test_exchange_raises_when_service_account_is_none( service_account_exchanger, auth_scheme ): - """Test missing service account info during exchange.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, ) @@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info( assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_exchange_failure( - service_account_exchanger, auth_scheme, monkeypatch +def test_exchange_wraps_google_auth_error_as_missing_error( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test failure during service account token exchange.""" - mock_from_service_account_info = MagicMock( - side_effect=Exception("Failed to load credentials") - ) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to load credentials") ) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." - ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) + with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, auth_credential) assert "Failed to exchange service account token" in str(exc_info.value) - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() + + +def test_exchange_raises_when_explicit_credentials_have_no_scopes( + service_account_exchanger, auth_scheme, sa_credential +): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "scopes are required" in str(exc_info.value) + + +# --- ID token exchange tests --- + + +def test_exchange_id_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_id_credentials = MagicMock() + mock_id_credentials.token = "mock_id_token" + mock_from_sa_info = MagicMock(return_value=mock_id_credentials) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert result.http.additional_headers is None + mock_from_sa_info.assert_called_once() + assert ( + mock_from_sa_info.call_args[1]["target_audience"] + == "https://my-service.run.app" + ) + mock_id_credentials.refresh.assert_called_once() + + +def test_exchange_id_token_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token") + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_adc_id_token" + assert result.http.additional_headers is None + mock_fetch_id_token.assert_called_once() + assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app" + + +def test_id_token_requires_audience(): + with pytest.raises( + ValueError, match="audience is required when use_id_token is True" + ): + ServiceAccount( + use_default_credential=True, + use_id_token=True, + ) + + +def test_exchange_id_token_wraps_error_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to create ID token credentials") + ) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +def test_exchange_id_token_wraps_error_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock( + side_effect=google_auth_exceptions.DefaultCredentialsError( + "Metadata service unavailable" + ) + ) + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +# --- Model validator tests --- + + +def test_model_validator_rejects_missing_credential_without_adc(): + with pytest.raises( + ValueError, + match="service_account_credential is required", + ): + ServiceAccount( + use_default_credential=False, + scopes=_DEFAULT_SCOPES, + ) + + +def test_model_validator_allows_adc_without_explicit_credential(): + sa = ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + ) + assert sa.service_account_credential is None + assert sa.use_default_credential is True