You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support ID token exchange in ServiceAccountCredentialExchanger
Adds use_id_token and audience fields to ServiceAccount so that ServiceAccountCredentialExchanger can produce ID tokens instead of access tokens. This is required for authenticating to Cloud Run, Cloud Functions, and other Google Cloud services that verify caller identity. Close #4458 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 874630210
This commit is contained in:
committed by
Copybara-Service
parent
c615757ba1
commit
7be90db24b
@@ -25,6 +25,7 @@ from pydantic import alias_generators
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
|
||||
class BaseModelWithConfig(BaseModel):
|
||||
@@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig):
|
||||
|
||||
|
||||
class ServiceAccount(BaseModelWithConfig):
|
||||
"""Represents Google Service Account configuration."""
|
||||
"""Represents Google Service Account configuration.
|
||||
|
||||
Attributes:
|
||||
service_account_credential: The service account credential (JSON key).
|
||||
scopes: The OAuth2 scopes to request. Optional; when omitted with
|
||||
``use_default_credential=True``, defaults to the cloud-platform scope.
|
||||
use_default_credential: Whether to use Application Default Credentials.
|
||||
use_id_token: Whether to exchange for an ID token instead of an access
|
||||
token. Required for service-to-service authentication with Cloud Run,
|
||||
Cloud Functions, and other Google Cloud services that require identity
|
||||
verification. When True, ``audience`` must also be set.
|
||||
audience: The target audience for the ID token, typically the URL of the
|
||||
receiving service (e.g. ``https://my-service-xyz.run.app``). Required
|
||||
when ``use_id_token`` is True.
|
||||
"""
|
||||
|
||||
service_account_credential: Optional[ServiceAccountCredential] = None
|
||||
scopes: List[str]
|
||||
scopes: Optional[List[str]] = None
|
||||
use_default_credential: Optional[bool] = False
|
||||
use_id_token: Optional[bool] = False
|
||||
audience: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_config(self) -> ServiceAccount:
|
||||
if (
|
||||
not self.use_default_credential
|
||||
and self.service_account_credential is None
|
||||
):
|
||||
raise ValueError(
|
||||
"service_account_credential is required when"
|
||||
" use_default_credential is False."
|
||||
)
|
||||
if self.use_id_token and not self.audience:
|
||||
raise ValueError(
|
||||
"audience is required when use_id_token is True. Set it to the"
|
||||
" URL of the target service"
|
||||
" (e.g. 'https://my-service.run.app')."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AuthCredentialTypes(str, Enum):
|
||||
|
||||
+113
-27
@@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
from typing import Optional
|
||||
|
||||
import google.auth
|
||||
from google.auth import exceptions as google_auth_exceptions
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2 import service_account
|
||||
import google.oauth2.credentials
|
||||
@@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredential
|
||||
from .....auth.auth_credential import AuthCredentialTypes
|
||||
from .....auth.auth_credential import HttpAuth
|
||||
from .....auth.auth_credential import HttpCredentials
|
||||
from .....auth.auth_credential import ServiceAccount
|
||||
from .....auth.auth_schemes import AuthScheme
|
||||
from .base_credential_exchanger import AuthCredentialMissingError
|
||||
from .base_credential_exchanger import BaseAuthCredentialExchanger
|
||||
@@ -38,6 +40,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
Uses the default service credential if `use_default_credential = True`.
|
||||
Otherwise, uses the service account credential provided in the auth
|
||||
credential.
|
||||
|
||||
Supports exchanging for either an access token (default) or an ID token
|
||||
when ``ServiceAccount.use_id_token`` is True. ID tokens are required for
|
||||
service-to-service authentication with Cloud Run, Cloud Functions, and
|
||||
other services that verify caller identity.
|
||||
"""
|
||||
|
||||
def exchange_credential(
|
||||
@@ -45,52 +52,130 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
auth_scheme: AuthScheme,
|
||||
auth_credential: Optional[AuthCredential] = None,
|
||||
) -> AuthCredential:
|
||||
"""Exchanges the service account auth credential for an access token.
|
||||
"""Exchanges the service account auth credential for a token.
|
||||
|
||||
If auth_credential contains a service account credential, it will be used
|
||||
to fetch an access token. Otherwise, the default service credential will be
|
||||
used for fetching an access token.
|
||||
to fetch a token. Otherwise, the default service credential will be
|
||||
used for fetching a token.
|
||||
|
||||
When ``service_account.use_id_token`` is True, an ID token is fetched
|
||||
using the configured ``audience``. This is required for authenticating
|
||||
to Cloud Run, Cloud Functions, and similar services.
|
||||
|
||||
Args:
|
||||
auth_scheme: The auth scheme.
|
||||
auth_credential: The auth credential.
|
||||
|
||||
Returns:
|
||||
An AuthCredential in HTTPBearer format, containing the access token.
|
||||
An AuthCredential in HTTPBearer format, containing the token.
|
||||
"""
|
||||
if (
|
||||
auth_credential is None
|
||||
or auth_credential.service_account is None
|
||||
or (
|
||||
auth_credential.service_account.service_account_credential is None
|
||||
and not auth_credential.service_account.use_default_credential
|
||||
)
|
||||
):
|
||||
if auth_credential is None or auth_credential.service_account is None:
|
||||
raise AuthCredentialMissingError(
|
||||
"Service account credentials are missing. Please provide them, or set"
|
||||
" `use_default_credential = True` to use application default"
|
||||
"Service account credentials are missing. Please provide them, or"
|
||||
" set `use_default_credential = True` to use application default"
|
||||
" credential in a hosted service like Cloud Run."
|
||||
)
|
||||
|
||||
sa_config = auth_credential.service_account
|
||||
|
||||
if sa_config.use_id_token:
|
||||
return self._exchange_for_id_token(sa_config)
|
||||
|
||||
return self._exchange_for_access_token(sa_config)
|
||||
|
||||
def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential:
|
||||
"""Exchanges the service account credential for an ID token.
|
||||
|
||||
Args:
|
||||
sa_config: The service account configuration.
|
||||
|
||||
Returns:
|
||||
An AuthCredential in HTTPBearer format containing the ID token.
|
||||
|
||||
Raises:
|
||||
AuthCredentialMissingError: If token exchange fails.
|
||||
"""
|
||||
# audience and credential presence are validated by the ServiceAccount
|
||||
# model_validator at construction time.
|
||||
try:
|
||||
if auth_credential.service_account.use_default_credential:
|
||||
credentials, project_id = google.auth.default(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
quota_project_id = (
|
||||
getattr(credentials, "quota_project_id", None) or project_id
|
||||
)
|
||||
if sa_config.use_default_credential:
|
||||
from google.oauth2 import id_token as oauth2_id_token
|
||||
|
||||
request = Request()
|
||||
token = oauth2_id_token.fetch_id_token(request, sa_config.audience)
|
||||
else:
|
||||
config = auth_credential.service_account
|
||||
# Guaranteed non-None by ServiceAccount model_validator.
|
||||
assert sa_config.service_account_credential is not None
|
||||
credentials = (
|
||||
service_account.IDTokenCredentials.from_service_account_info(
|
||||
sa_config.service_account_credential.model_dump(),
|
||||
target_audience=sa_config.audience,
|
||||
)
|
||||
)
|
||||
credentials.refresh(Request())
|
||||
token = credentials.token
|
||||
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token=token),
|
||||
),
|
||||
)
|
||||
|
||||
# ValueError is raised by google-auth when service account JSON is
|
||||
# missing required fields (e.g. client_email, private_key), or when
|
||||
# fetch_id_token cannot determine credentials from the environment.
|
||||
except (google_auth_exceptions.GoogleAuthError, ValueError) as e:
|
||||
raise AuthCredentialMissingError(
|
||||
f"Failed to exchange service account for ID token: {e}"
|
||||
) from e
|
||||
|
||||
def _exchange_for_access_token(
|
||||
self, sa_config: ServiceAccount
|
||||
) -> AuthCredential:
|
||||
"""Exchanges the service account credential for an access token.
|
||||
|
||||
Args:
|
||||
sa_config: The service account configuration.
|
||||
|
||||
Returns:
|
||||
An AuthCredential in HTTPBearer format containing the access token.
|
||||
|
||||
Raises:
|
||||
AuthCredentialMissingError: If scopes are missing for explicit
|
||||
credentials or token exchange fails.
|
||||
"""
|
||||
if not sa_config.use_default_credential and not sa_config.scopes:
|
||||
raise AuthCredentialMissingError(
|
||||
"scopes are required when using explicit service account credentials"
|
||||
" for access token exchange."
|
||||
)
|
||||
|
||||
try:
|
||||
if sa_config.use_default_credential:
|
||||
scopes = (
|
||||
sa_config.scopes
|
||||
if sa_config.scopes
|
||||
else ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
)
|
||||
credentials, project_id = google.auth.default(
|
||||
scopes=scopes,
|
||||
)
|
||||
quota_project_id = credentials.quota_project_id or project_id
|
||||
else:
|
||||
# Guaranteed non-None by ServiceAccount model_validator.
|
||||
assert sa_config.service_account_credential is not None
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
config.service_account_credential.model_dump(), scopes=config.scopes
|
||||
sa_config.service_account_credential.model_dump(),
|
||||
scopes=sa_config.scopes,
|
||||
)
|
||||
quota_project_id = None
|
||||
|
||||
credentials.refresh(Request())
|
||||
|
||||
updated_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token
|
||||
return AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer",
|
||||
credentials=HttpCredentials(token=credentials.token),
|
||||
@@ -101,9 +186,10 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger):
|
||||
else None,
|
||||
),
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
except Exception as e:
|
||||
# ValueError is raised by google-auth when service account JSON is
|
||||
# missing required fields (e.g. client_email, private_key).
|
||||
except (google_auth_exceptions.GoogleAuthError, ValueError) as e:
|
||||
raise AuthCredentialMissingError(
|
||||
f"Failed to exchange service account token: {e}"
|
||||
) from e
|
||||
|
||||
@@ -534,7 +534,9 @@ class TestMCPTool:
|
||||
)
|
||||
|
||||
# Create service account credential
|
||||
service_account = ServiceAccount(scopes=["test"])
|
||||
service_account = ServiceAccount(
|
||||
scopes=["test"], use_default_credential=True
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=service_account,
|
||||
|
||||
+253
-80
@@ -25,8 +25,23 @@ from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError
|
||||
from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
import google.auth
|
||||
from google.auth import exceptions as google_auth_exceptions
|
||||
import pytest
|
||||
|
||||
_ACCESS_TOKEN_MONKEYPATCH_TARGET = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
|
||||
_ID_TOKEN_MONKEYPATCH_TARGET = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.IDTokenCredentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
|
||||
_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service_account_exchanger():
|
||||
@@ -41,50 +56,45 @@ def auth_scheme():
|
||||
return scheme
|
||||
|
||||
|
||||
def test_exchange_credential_success(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
@pytest.fixture
|
||||
def sa_credential():
|
||||
"""A minimal valid ServiceAccountCredential for testing."""
|
||||
return ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="test_project_id",
|
||||
private_key_id="test_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="test@test.iam.gserviceaccount.com",
|
||||
client_id="test_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/test"
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
|
||||
|
||||
# --- Access token exchange tests ---
|
||||
|
||||
|
||||
def test_exchange_access_token_with_explicit_credentials(
|
||||
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
|
||||
):
|
||||
"""Test successful exchange of service account credentials."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_from_sa_info = MagicMock(return_value=mock_credentials)
|
||||
monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
|
||||
|
||||
# Mock the from_service_account_info method
|
||||
mock_from_service_account_info = MagicMock(return_value=mock_credentials)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
)
|
||||
|
||||
# Mock the refresh method
|
||||
mock_credentials.refresh = MagicMock()
|
||||
|
||||
# Create a valid AuthCredential with service account info
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
service_account_credential=sa_credential,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -95,7 +105,7 @@ def test_exchange_credential_success(
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
mock_from_sa_info.assert_called_once()
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
@@ -107,7 +117,7 @@ def test_exchange_credential_success(
|
||||
(None, None, None),
|
||||
],
|
||||
)
|
||||
def test_exchange_credential_use_default_credential_success(
|
||||
def test_exchange_access_token_with_adc_sets_quota_project(
|
||||
service_account_exchanger,
|
||||
auth_scheme,
|
||||
monkeypatch,
|
||||
@@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success(
|
||||
adc_project_id,
|
||||
expected_quota_project_id,
|
||||
):
|
||||
"""Test successful exchange of service account credentials using default credential."""
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_credentials.quota_project_id = cred_quota_project_id
|
||||
@@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
scopes=["https://www.googleapis.com/auth/bigquery"],
|
||||
),
|
||||
)
|
||||
|
||||
@@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success(
|
||||
)
|
||||
else:
|
||||
assert not result.http.additional_headers
|
||||
# Verify google.auth.default is called with the correct scopes parameter
|
||||
mock_google_auth_default.assert_called_once_with(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"]
|
||||
scopes=["https://www.googleapis.com/auth/bigquery"]
|
||||
)
|
||||
mock_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_credential_missing_auth_credential(
|
||||
def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_credentials.token = "mock_access_token"
|
||||
mock_credentials.quota_project_id = None
|
||||
mock_google_auth_default = MagicMock(return_value=(mock_credentials, None))
|
||||
monkeypatch.setattr(google.auth, "default", mock_google_auth_default)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_access_token"
|
||||
mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES)
|
||||
|
||||
|
||||
def test_exchange_raises_when_auth_credential_is_none(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing auth credential during exchange."""
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, None)
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_missing_service_account_info(
|
||||
def test_exchange_raises_when_service_account_is_none(
|
||||
service_account_exchanger, auth_scheme
|
||||
):
|
||||
"""Test missing service account info during exchange."""
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
)
|
||||
@@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info(
|
||||
assert "Service account credentials are missing" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_exchange_credential_exchange_failure(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
def test_exchange_wraps_google_auth_error_as_missing_error(
|
||||
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
|
||||
):
|
||||
"""Test failure during service account token exchange."""
|
||||
mock_from_service_account_info = MagicMock(
|
||||
side_effect=Exception("Failed to load credentials")
|
||||
)
|
||||
target_path = (
|
||||
"google.adk.tools.openapi_tool.auth.credential_exchangers."
|
||||
"service_account_exchanger.service_account.Credentials."
|
||||
"from_service_account_info"
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
target_path,
|
||||
mock_from_service_account_info,
|
||||
mock_from_sa_info = MagicMock(
|
||||
side_effect=ValueError("Failed to load credentials")
|
||||
)
|
||||
monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=ServiceAccountCredential(
|
||||
type_="service_account",
|
||||
project_id="your_project_id",
|
||||
private_key_id="your_private_key_id",
|
||||
private_key="-----BEGIN PRIVATE KEY-----...",
|
||||
client_email="...@....iam.gserviceaccount.com",
|
||||
client_id="your_client_id",
|
||||
auth_uri="https://accounts.google.com/o/oauth2/auth",
|
||||
token_uri="https://oauth2.googleapis.com/token",
|
||||
auth_provider_x509_cert_url=(
|
||||
"https://www.googleapis.com/oauth2/v1/certs"
|
||||
),
|
||||
client_x509_cert_url=(
|
||||
"https://www.googleapis.com/robot/v1/metadata/x509/..."
|
||||
),
|
||||
universe_domain="googleapis.com",
|
||||
),
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
service_account_credential=sa_credential,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account token" in str(exc_info.value)
|
||||
mock_from_service_account_info.assert_called_once()
|
||||
mock_from_sa_info.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_raises_when_explicit_credentials_have_no_scopes(
|
||||
service_account_exchanger, auth_scheme, sa_credential
|
||||
):
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=sa_credential,
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "scopes are required" in str(exc_info.value)
|
||||
|
||||
|
||||
# --- ID token exchange tests ---
|
||||
|
||||
|
||||
def test_exchange_id_token_with_explicit_credentials(
|
||||
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
|
||||
):
|
||||
mock_id_credentials = MagicMock()
|
||||
mock_id_credentials.token = "mock_id_token"
|
||||
mock_from_sa_info = MagicMock(return_value=mock_id_credentials)
|
||||
monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=sa_credential,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
use_id_token=True,
|
||||
audience="https://my-service.run.app",
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_id_token"
|
||||
assert result.http.additional_headers is None
|
||||
mock_from_sa_info.assert_called_once()
|
||||
assert (
|
||||
mock_from_sa_info.call_args[1]["target_audience"]
|
||||
== "https://my-service.run.app"
|
||||
)
|
||||
mock_id_credentials.refresh.assert_called_once()
|
||||
|
||||
|
||||
def test_exchange_id_token_with_adc(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token")
|
||||
monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
use_id_token=True,
|
||||
audience="https://my-service.run.app",
|
||||
),
|
||||
)
|
||||
|
||||
result = service_account_exchanger.exchange_credential(
|
||||
auth_scheme, auth_credential
|
||||
)
|
||||
|
||||
assert result.auth_type == AuthCredentialTypes.HTTP
|
||||
assert result.http.scheme == "bearer"
|
||||
assert result.http.credentials.token == "mock_adc_id_token"
|
||||
assert result.http.additional_headers is None
|
||||
mock_fetch_id_token.assert_called_once()
|
||||
assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app"
|
||||
|
||||
|
||||
def test_id_token_requires_audience():
|
||||
with pytest.raises(
|
||||
ValueError, match="audience is required when use_id_token is True"
|
||||
):
|
||||
ServiceAccount(
|
||||
use_default_credential=True,
|
||||
use_id_token=True,
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_id_token_wraps_error_with_explicit_credentials(
|
||||
service_account_exchanger, auth_scheme, sa_credential, monkeypatch
|
||||
):
|
||||
mock_from_sa_info = MagicMock(
|
||||
side_effect=ValueError("Failed to create ID token credentials")
|
||||
)
|
||||
monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
service_account_credential=sa_credential,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
use_id_token=True,
|
||||
audience="https://my-service.run.app",
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account for ID token" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
def test_exchange_id_token_wraps_error_with_adc(
|
||||
service_account_exchanger, auth_scheme, monkeypatch
|
||||
):
|
||||
mock_fetch_id_token = MagicMock(
|
||||
side_effect=google_auth_exceptions.DefaultCredentialsError(
|
||||
"Metadata service unavailable"
|
||||
)
|
||||
)
|
||||
monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token)
|
||||
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
use_id_token=True,
|
||||
audience="https://my-service.run.app",
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(AuthCredentialMissingError) as exc_info:
|
||||
service_account_exchanger.exchange_credential(auth_scheme, auth_credential)
|
||||
assert "Failed to exchange service account for ID token" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
|
||||
# --- Model validator tests ---
|
||||
|
||||
|
||||
def test_model_validator_rejects_missing_credential_without_adc():
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="service_account_credential is required",
|
||||
):
|
||||
ServiceAccount(
|
||||
use_default_credential=False,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
)
|
||||
|
||||
|
||||
def test_model_validator_allows_adc_without_explicit_credential():
|
||||
sa = ServiceAccount(
|
||||
use_default_credential=True,
|
||||
scopes=_DEFAULT_SCOPES,
|
||||
)
|
||||
assert sa.service_account_credential is None
|
||||
assert sa.use_default_credential is True
|
||||
|
||||
Reference in New Issue
Block a user