feat: Support ID token exchange in ServiceAccountCredentialExchanger

Adds use_id_token and audience fields to ServiceAccount so that
ServiceAccountCredentialExchanger can produce ID tokens instead of
access tokens. This is required for authenticating to Cloud Run, Cloud
Functions, and other Google Cloud services that verify caller identity.
Close #4458

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 874630210
This commit is contained in:
George Weale
2026-02-24 08:38:34 -08:00
committed by Copybara-Service
parent c615757ba1
commit 7be90db24b
4 changed files with 406 additions and 110 deletions
+37 -2
View File
@@ -25,6 +25,7 @@ from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
class BaseModelWithConfig(BaseModel):
@@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig):
class ServiceAccount(BaseModelWithConfig):
"""Represents Google Service Account configuration."""
"""Represents Google Service Account configuration.
Attributes:
service_account_credential: The service account credential (JSON key).
scopes: The OAuth2 scopes to request. Optional; when omitted with
``use_default_credential=True``, defaults to the cloud-platform scope.
use_default_credential: Whether to use Application Default Credentials.
use_id_token: Whether to exchange for an ID token instead of an access
token. Required for service-to-service authentication with Cloud Run,
Cloud Functions, and other Google Cloud services that require identity
verification. When True, ``audience`` must also be set.
audience: The target audience for the ID token, typically the URL of the
receiving service (e.g. ``https://my-service-xyz.run.app``). Required
when ``use_id_token`` is True.
"""
service_account_credential: Optional[ServiceAccountCredential] = None
scopes: List[str]
scopes: Optional[List[str]] = None
use_default_credential: Optional[bool] = False
use_id_token: Optional[bool] = False
audience: Optional[str] = None
@model_validator(mode="after")
def _validate_config(self) -> ServiceAccount:
if (
not self.use_default_credential
and self.service_account_credential is None
):
raise ValueError(
"service_account_credential is required when"
" use_default_credential is False."
)
if self.use_id_token and not self.audience:
raise ValueError(
"audience is required when use_id_token is True. Set it to the"
" URL of the target service"
" (e.g. 'https://my-service.run.app')."
)
return self
class AuthCredentialTypes(str, Enum):
@@ -19,6 +19,7 @@ from __future__ import annotations
from typing import Optional
import google.auth
from google.auth import exceptions as google_auth_exceptions
from google.auth.transport.requests import Request
from google.oauth2 import service_account
import google.oauth2.credentials
@@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredential
from .....auth.auth_credential import AuthCredentialTypes
from .....auth.auth_credential import HttpAuth
from .....auth.auth_credential import HttpCredentials
from .....auth.auth_credential import ServiceAccount
from .....auth.auth_schemes import AuthScheme
from .base_credential_exchanger import AuthCredentialMissingError
from .base_credential_exchanger import BaseAuthCredentialExchanger
@@ -38,6 +40,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
Uses the default service credential if `use_default_credential = True`.
Otherwise, uses the service account credential provided in the auth
credential.
Supports exchanging for either an access token (default) or an ID token
when ``ServiceAccount.use_id_token`` is True. ID tokens are required for
service-to-service authentication with Cloud Run, Cloud Functions, and
other services that verify caller identity.
"""
def exchange_credential(
@@ -45,52 +52,130 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
auth_scheme: AuthScheme,
auth_credential: Optional[AuthCredential] = None,
) -> AuthCredential:
"""Exchanges the service account auth credential for an access token.
"""Exchanges the service account auth credential for a token.
If auth_credential contains a service account credential, it will be used
to fetch an access token. Otherwise, the default service credential will be
used for fetching an access token.
to fetch a token. Otherwise, the default service credential will be
used for fetching a token.
When ``service_account.use_id_token`` is True, an ID token is fetched
using the configured ``audience``. This is required for authenticating
to Cloud Run, Cloud Functions, and similar services.
Args:
auth_scheme: The auth scheme.
auth_credential: The auth credential.
Returns:
An AuthCredential in HTTPBearer format, containing the access token.
An AuthCredential in HTTPBearer format, containing the token.
"""
if (
auth_credential is None
or auth_credential.service_account is None
or (
auth_credential.service_account.service_account_credential is None
and not auth_credential.service_account.use_default_credential
)
):
if auth_credential is None or auth_credential.service_account is None:
raise AuthCredentialMissingError(
"Service account credentials are missing. Please provide them, or set"
" `use_default_credential = True` to use application default"
"Service account credentials are missing. Please provide them, or"
" set `use_default_credential = True` to use application default"
" credential in a hosted service like Cloud Run."
)
sa_config = auth_credential.service_account
if sa_config.use_id_token:
return self._exchange_for_id_token(sa_config)
return self._exchange_for_access_token(sa_config)
def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential:
"""Exchanges the service account credential for an ID token.
Args:
sa_config: The service account configuration.
Returns:
An AuthCredential in HTTPBearer format containing the ID token.
Raises:
AuthCredentialMissingError: If token exchange fails.
"""
# audience and credential presence are validated by the ServiceAccount
# model_validator at construction time.
try:
if auth_credential.service_account.use_default_credential:
credentials, project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
quota_project_id = (
getattr(credentials, "quota_project_id", None) or project_id
)
if sa_config.use_default_credential:
from google.oauth2 import id_token as oauth2_id_token
request = Request()
token = oauth2_id_token.fetch_id_token(request, sa_config.audience)
else:
config = auth_credential.service_account
# Guaranteed non-None by ServiceAccount model_validator.
assert sa_config.service_account_credential is not None
credentials = (
service_account.IDTokenCredentials.from_service_account_info(
sa_config.service_account_credential.model_dump(),
target_audience=sa_config.audience,
)
)
credentials.refresh(Request())
token = credentials.token
return AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=token),
),
)
# ValueError is raised by google-auth when service account JSON is
# missing required fields (e.g. client_email, private_key), or when
# fetch_id_token cannot determine credentials from the environment.
except (google_auth_exceptions.GoogleAuthError, ValueError) as e:
raise AuthCredentialMissingError(
f"Failed to exchange service account for ID token: {e}"
) from e
def _exchange_for_access_token(
self, sa_config: ServiceAccount
) -> AuthCredential:
"""Exchanges the service account credential for an access token.
Args:
sa_config: The service account configuration.
Returns:
An AuthCredential in HTTPBearer format containing the access token.
Raises:
AuthCredentialMissingError: If scopes are missing for explicit
credentials or token exchange fails.
"""
if not sa_config.use_default_credential and not sa_config.scopes:
raise AuthCredentialMissingError(
"scopes are required when using explicit service account credentials"
" for access token exchange."
)
try:
if sa_config.use_default_credential:
scopes = (
sa_config.scopes
if sa_config.scopes
else ["https://www.googleapis.com/auth/cloud-platform"]
)
credentials, project_id = google.auth.default(
scopes=scopes,
)
quota_project_id = credentials.quota_project_id or project_id
else:
# Guaranteed non-None by ServiceAccount model_validator.
assert sa_config.service_account_credential is not None
credentials = service_account.Credentials.from_service_account_info(
config.service_account_credential.model_dump(), scopes=config.scopes
sa_config.service_account_credential.model_dump(),
scopes=sa_config.scopes,
)
quota_project_id = None
credentials.refresh(Request())
updated_credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
return AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer",
credentials=HttpCredentials(token=credentials.token),
@@ -101,9 +186,10 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
else None,
),
)
return updated_credential
except Exception as e:
# ValueError is raised by google-auth when service account JSON is
# missing required fields (e.g. client_email, private_key).
except (google_auth_exceptions.GoogleAuthError, ValueError) as e:
raise AuthCredentialMissingError(
f"Failed to exchange service account token: {e}"
) from e
@@ -534,7 +534,9 @@ class TestMCPTool:
)
# Create service account credential
service_account = ServiceAccount(scopes=["test"])
service_account = ServiceAccount(
scopes=["test"], use_default_credential=True
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
@@ -25,8 +25,23 @@ from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
import google.auth
from google.auth import exceptions as google_auth_exceptions
import pytest
_ACCESS_TOKEN_MONKEYPATCH_TARGET = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
_ID_TOKEN_MONKEYPATCH_TARGET = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.IDTokenCredentials."
"from_service_account_info"
)
_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token"
@pytest.fixture
def service_account_exchanger():
@@ -41,50 +56,45 @@ def auth_scheme():
return scheme
def test_exchange_credential_success(
service_account_exchanger, auth_scheme, monkeypatch
@pytest.fixture
def sa_credential():
"""A minimal valid ServiceAccountCredential for testing."""
return ServiceAccountCredential(
type_="service_account",
project_id="test_project_id",
private_key_id="test_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="test@test.iam.gserviceaccount.com",
client_id="test_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/test"
),
universe_domain="googleapis.com",
)
_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
# --- Access token exchange tests ---
def test_exchange_access_token_with_explicit_credentials(
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
):
"""Test successful exchange of service account credentials."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_from_sa_info = MagicMock(return_value=mock_credentials)
monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
# Mock the from_service_account_info method
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
)
# Mock the refresh method
mock_credentials.refresh = MagicMock()
# Create a valid AuthCredential with service account info
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
service_account_credential=sa_credential,
scopes=_DEFAULT_SCOPES,
),
)
@@ -95,7 +105,7 @@ def test_exchange_credential_success(
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_from_service_account_info.assert_called_once()
mock_from_sa_info.assert_called_once()
mock_credentials.refresh.assert_called_once()
@@ -107,7 +117,7 @@ def test_exchange_credential_success(
(None, None, None),
],
)
def test_exchange_credential_use_default_credential_success(
def test_exchange_access_token_with_adc_sets_quota_project(
service_account_exchanger,
auth_scheme,
monkeypatch,
@@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success(
adc_project_id,
expected_quota_project_id,
):
"""Test successful exchange of service account credentials using default credential."""
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_credentials.quota_project_id = cred_quota_project_id
@@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
scopes=["https://www.googleapis.com/auth/bigquery"],
),
)
@@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success(
)
else:
assert not result.http.additional_headers
# Verify google.auth.default is called with the correct scopes parameter
mock_google_auth_default.assert_called_once_with(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
scopes=["https://www.googleapis.com/auth/bigquery"]
)
mock_credentials.refresh.assert_called_once()
def test_exchange_credential_missing_auth_credential(
def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope(
service_account_exchanger, auth_scheme, monkeypatch
):
mock_credentials = MagicMock()
mock_credentials.token = "mock_access_token"
mock_credentials.quota_project_id = None
mock_google_auth_default = MagicMock(return_value=(mock_credentials, None))
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_access_token"
mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES)
def test_exchange_raises_when_auth_credential_is_none(
service_account_exchanger, auth_scheme
):
"""Test missing auth credential during exchange."""
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, None)
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_missing_service_account_info(
def test_exchange_raises_when_service_account_is_none(
service_account_exchanger, auth_scheme
):
"""Test missing service account info during exchange."""
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
)
@@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info(
assert "Service account credentials are missing" in str(exc_info.value)
def test_exchange_credential_exchange_failure(
service_account_exchanger, auth_scheme, monkeypatch
def test_exchange_wraps_google_auth_error_as_missing_error(
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
):
"""Test failure during service account token exchange."""
mock_from_service_account_info = MagicMock(
side_effect=Exception("Failed to load credentials")
)
target_path = (
"google.adk.tools.openapi_tool.auth.credential_exchangers."
"service_account_exchanger.service_account.Credentials."
"from_service_account_info"
)
monkeypatch.setattr(
target_path,
mock_from_service_account_info,
mock_from_sa_info = MagicMock(
side_effect=ValueError("Failed to load credentials")
)
monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=ServiceAccountCredential(
type_="service_account",
project_id="your_project_id",
private_key_id="your_private_key_id",
private_key="-----BEGIN PRIVATE KEY-----...",
client_email="...@....iam.gserviceaccount.com",
client_id="your_client_id",
auth_uri="https://accounts.google.com/o/oauth2/auth",
token_uri="https://oauth2.googleapis.com/token",
auth_provider_x509_cert_url=(
"https://www.googleapis.com/oauth2/v1/certs"
),
client_x509_cert_url=(
"https://www.googleapis.com/robot/v1/metadata/x509/..."
),
universe_domain="googleapis.com",
),
scopes=["https://www.googleapis.com/auth/cloud-platform"],
service_account_credential=sa_credential,
scopes=_DEFAULT_SCOPES,
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account token" in str(exc_info.value)
mock_from_service_account_info.assert_called_once()
mock_from_sa_info.assert_called_once()
def test_exchange_raises_when_explicit_credentials_have_no_scopes(
service_account_exchanger, auth_scheme, sa_credential
):
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=sa_credential,
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "scopes are required" in str(exc_info.value)
# --- ID token exchange tests ---
def test_exchange_id_token_with_explicit_credentials(
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
):
mock_id_credentials = MagicMock()
mock_id_credentials.token = "mock_id_token"
mock_from_sa_info = MagicMock(return_value=mock_id_credentials)
monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=sa_credential,
scopes=_DEFAULT_SCOPES,
use_id_token=True,
audience="https://my-service.run.app",
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_id_token"
assert result.http.additional_headers is None
mock_from_sa_info.assert_called_once()
assert (
mock_from_sa_info.call_args[1]["target_audience"]
== "https://my-service.run.app"
)
mock_id_credentials.refresh.assert_called_once()
def test_exchange_id_token_with_adc(
service_account_exchanger, auth_scheme, monkeypatch
):
mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token")
monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=_DEFAULT_SCOPES,
use_id_token=True,
audience="https://my-service.run.app",
),
)
result = service_account_exchanger.exchange_credential(
auth_scheme, auth_credential
)
assert result.auth_type == AuthCredentialTypes.HTTP
assert result.http.scheme == "bearer"
assert result.http.credentials.token == "mock_adc_id_token"
assert result.http.additional_headers is None
mock_fetch_id_token.assert_called_once()
assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app"
def test_id_token_requires_audience():
with pytest.raises(
ValueError, match="audience is required when use_id_token is True"
):
ServiceAccount(
use_default_credential=True,
use_id_token=True,
)
def test_exchange_id_token_wraps_error_with_explicit_credentials(
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
):
mock_from_sa_info = MagicMock(
side_effect=ValueError("Failed to create ID token credentials")
)
monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
service_account_credential=sa_credential,
scopes=_DEFAULT_SCOPES,
use_id_token=True,
audience="https://my-service.run.app",
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account for ID token" in str(
exc_info.value
)
def test_exchange_id_token_wraps_error_with_adc(
service_account_exchanger, auth_scheme, monkeypatch
):
mock_fetch_id_token = MagicMock(
side_effect=google_auth_exceptions.DefaultCredentialsError(
"Metadata service unavailable"
)
)
monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(
use_default_credential=True,
scopes=_DEFAULT_SCOPES,
use_id_token=True,
audience="https://my-service.run.app",
),
)
with pytest.raises(AuthCredentialMissingError) as exc_info:
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
assert "Failed to exchange service account for ID token" in str(
exc_info.value
)
# --- Model validator tests ---
def test_model_validator_rejects_missing_credential_without_adc():
with pytest.raises(
ValueError,
match="service_account_credential is required",
):
ServiceAccount(
use_default_credential=False,
scopes=_DEFAULT_SCOPES,
)
def test_model_validator_allows_adc_without_explicit_credential():
sa = ServiceAccount(
use_default_credential=True,
scopes=_DEFAULT_SCOPES,
)
assert sa.service_account_credential is None
assert sa.use_default_credential is True