feat: Introduce ExtendedOAuth2 scheme that auto-populates auth/token URLs

Use auto-discovered auth_endpoint and token_endpoint in CredentialManager.

PiperOrigin-RevId: 811183929
This commit is contained in:
Google Team Member
2025-09-24 22:20:37 -07:00
committed by Copybara-Service
parent f159bd9c87
commit c8c6cd70a4
3 changed files with 199 additions and 1 deletions
+12
View File
@@ -12,17 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from typing import List
from typing import Optional
from typing import Union
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlows
from fastapi.openapi.models import SecurityBase
from fastapi.openapi.models import SecurityScheme
from fastapi.openapi.models import SecuritySchemeType
from pydantic import Field
from ..utils.feature_decorator import experimental
class OpenIdConnectWithConfig(SecurityBase):
type_: SecuritySchemeType = Field(
@@ -65,3 +70,10 @@ class OAuthGrantType(str, Enum):
# AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0.
AuthSchemeType = SecuritySchemeType
@experimental
class ExtendedOAuth2(OAuth2):
"""OAuth2 scheme that incorporates auto-discovery for endpoints."""
issuer_url: Optional[str] = None # Used for endpoint-discovery
+70 -1
View File
@@ -14,19 +14,26 @@
from __future__ import annotations
import logging
from typing import Optional
from fastapi.openapi.models import OAuth2
from ..agents.callback_context import CallbackContext
from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
from .auth_schemes import AuthSchemeType
from .auth_schemes import ExtendedOAuth2
from .auth_tool import AuthConfig
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
from .oauth2_discovery import OAuth2DiscoveryManager
from .refresher.base_credential_refresher import BaseCredentialRefresher
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
logger = logging.getLogger("google_adk." + __name__)
@experimental
class CredentialManager:
@@ -74,6 +81,7 @@ class CredentialManager:
self._auth_config = auth_config
self._exchanger_registry = CredentialExchangerRegistry()
self._refresher_registry = CredentialRefresherRegistry()
self._discovery_manager = OAuth2DiscoveryManager()
# Register default exchangers and refreshers
# TODO: support service account credential exchanger
@@ -247,7 +255,14 @@ class CredentialManager:
"auth_config.raw_credential.oauth2 required for credential type "
f"{raw_credential.auth_type}"
)
# Additional validation can be added here
if self._missing_oauth_info() and not await self._populate_auth_scheme():
raise ValueError(
"OAuth scheme info is missing, and auto-discovery has failed to fill"
" them in."
)
# Additional validation can be added here
async def _save_credential(
self, callback_context: CallbackContext, credential: AuthCredential
@@ -259,3 +274,57 @@ class CredentialManager:
credential_service = callback_context._invocation_context.credential_service
if credential_service:
await callback_context.save_credential(self._auth_config)
async def _populate_auth_scheme(self) -> bool:
"""Auto-discover server metadata and populate missing auth scheme info.
Returns:
True if auto-discovery was successful, False otherwise.
"""
auth_scheme = self._auth_config.auth_scheme
if (
not isinstance(auth_scheme, ExtendedOAuth2)
or not auth_scheme.issuer_url
):
logger.warning("No issuer_url was provided for auto-discovery.")
return False
metadata = await self._discovery_manager.discover_auth_server_metadata(
auth_scheme.issuer_url
)
if not metadata:
logger.warning("Auto-discovery has failed to populate OAuth scheme info.")
return False
flows = auth_scheme.flows
if flows.implicit and not flows.implicit.authorizationUrl:
flows.implicit.authorizationUrl = metadata.authorization_endpoint
if flows.password and not flows.password.tokenUrl:
flows.password.tokenUrl = metadata.token_endpoint
if flows.clientCredentials and not flows.clientCredentials.tokenUrl:
flows.clientCredentials.tokenUrl = metadata.token_endpoint
if flows.authorizationCode and not flows.authorizationCode.authorizationUrl:
flows.authorizationCode.authorizationUrl = metadata.authorization_endpoint
if flows.authorizationCode and not flows.authorizationCode.tokenUrl:
flows.authorizationCode.tokenUrl = metadata.token_endpoint
return True
def _missing_oauth_info(self) -> bool:
"""Checks if we are missing auth/token URLs needed for OAuth."""
auth_scheme = self._auth_config.auth_scheme
if isinstance(auth_scheme, OAuth2):
flows = auth_scheme.flows
return (
flows.implicit
and not flows.implicit.authorizationUrl
or flows.password
and not flows.password.tokenUrl
or flows.clientCredentials
and not flows.clientCredentials.tokenUrl
or flows.authorizationCode
and not flows.authorizationCode.authorizationUrl
or flows.authorizationCode
and not flows.authorizationCode.tokenUrl
)
return False
@@ -16,6 +16,10 @@ from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlowImplicit
from fastapi.openapi.models import OAuthFlows
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
@@ -23,8 +27,10 @@ from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_credential import ServiceAccountCredential
from google.adk.auth.auth_schemes import AuthScheme
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.auth.auth_schemes import ExtendedOAuth2
from google.adk.auth.auth_tool import AuthConfig
from google.adk.auth.credential_manager import CredentialManager
from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata
import pytest
@@ -390,6 +396,28 @@ class TestCredentialManager:
with pytest.raises(ValueError, match="oauth2 required for credential type"):
await manager._validate_credential()
@pytest.mark.asyncio
async def test_validate_credential_oauth2_missing_scheme_info(
self, extended_oauth2_scheme
):
"""Test _validate_credential with OAuth2 missing scheme info."""
mock_raw_credential = Mock(spec=AuthCredential)
mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2
mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth)
auth_config = Mock(spec=AuthConfig)
auth_config.raw_auth_credential = mock_raw_credential
auth_config.auth_scheme = extended_oauth2_scheme
manager = CredentialManager(auth_config)
with patch.object(
manager,
"_populate_auth_scheme",
return_value=False,
) and pytest.raises(ValueError, match="OAuth scheme info is missing"):
await manager._validate_credential()
@pytest.mark.asyncio
async def test_exchange_credentials_service_account(self):
"""Test _exchange_credential with service account credential."""
@@ -445,6 +473,95 @@ class TestCredentialManager:
assert result == mock_credential
assert was_exchanged is False
@pytest.fixture
def auth_server_metadata(self):
"""Create AuthorizationServerMetadata object."""
return AuthorizationServerMetadata(
issuer="https://auth.example.com",
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
scopes_supported=["read", "write"],
)
@pytest.fixture
def extended_oauth2_scheme(self):
"""Create ExtendedOAuth2 object with empty endpoints."""
return ExtendedOAuth2(
issuer_url="https://auth.example.com",
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="",
tokenUrl="",
)
),
)
@pytest.fixture
def implicit_oauth2_scheme(self):
"""Create OAuth2 object with implicit flow."""
return OAuth2(
flows=OAuthFlows(
implicit=OAuthFlowImplicit(
authorizationUrl="https://auth.example.com/authorize"
)
)
)
@pytest.mark.asyncio
async def test_populate_auth_scheme_success(
self, auth_server_metadata, extended_oauth2_scheme
):
"""Test _populate_auth_scheme successfully populates missing info."""
auth_config = Mock(spec=AuthConfig)
auth_config.auth_scheme = extended_oauth2_scheme
manager = CredentialManager(auth_config)
with patch.object(
manager._discovery_manager,
"discover_auth_server_metadata",
return_value=auth_server_metadata,
):
assert await manager._populate_auth_scheme()
assert (
manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl
== "https://auth.example.com/authorize"
)
assert (
manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl
== "https://auth.example.com/token"
)
@pytest.mark.asyncio
async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme):
"""Test _populate_auth_scheme when auto-discovery fails."""
auth_config = Mock(spec=AuthConfig)
auth_config.auth_scheme = extended_oauth2_scheme
manager = CredentialManager(auth_config)
with patch.object(
manager._discovery_manager,
"discover_auth_server_metadata",
return_value=None,
):
assert not await manager._populate_auth_scheme()
assert (
not manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl
)
assert not manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl
@pytest.mark.asyncio
async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme):
"""Test _populate_auth_scheme when auth scheme info not missing."""
auth_config = Mock(spec=AuthConfig)
auth_config.auth_scheme = implicit_oauth2_scheme
manager = CredentialManager(auth_config)
assert not await manager._populate_auth_scheme() # no-op
assert manager._auth_config.auth_scheme == implicit_oauth2_scheme
@pytest.fixture
def oauth2_auth_scheme():