From c8c6cd70a4d9e310267bd3b97ff721972ed282aa Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 24 Sep 2025 22:20:37 -0700 Subject: [PATCH] feat: Introduce ExtendedOAuth2 scheme that auto-populates auth/token URLs Use auto-discovered auth_endpoint and token_endpoint in CredentialManager. PiperOrigin-RevId: 811183929 --- src/google/adk/auth/auth_schemes.py | 12 ++ src/google/adk/auth/credential_manager.py | 71 ++++++++++- .../unittests/auth/test_credential_manager.py | 117 ++++++++++++++++++ 3 files changed, 199 insertions(+), 1 deletion(-) diff --git a/src/google/adk/auth/auth_schemes.py b/src/google/adk/auth/auth_schemes.py index baccf648..c170b957 100644 --- a/src/google/adk/auth/auth_schemes.py +++ b/src/google/adk/auth/auth_schemes.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from enum import Enum from typing import List from typing import Optional from typing import Union +from fastapi.openapi.models import OAuth2 from fastapi.openapi.models import OAuthFlows from fastapi.openapi.models import SecurityBase from fastapi.openapi.models import SecurityScheme from fastapi.openapi.models import SecuritySchemeType from pydantic import Field +from ..utils.feature_decorator import experimental + class OpenIdConnectWithConfig(SecurityBase): type_: SecuritySchemeType = Field( @@ -65,3 +70,10 @@ class OAuthGrantType(str, Enum): # AuthSchemeType re-exports SecuritySchemeType from OpenAPI 3.0. AuthSchemeType = SecuritySchemeType + + +@experimental +class ExtendedOAuth2(OAuth2): + """OAuth2 scheme that incorporates auto-discovery for endpoints.""" + + issuer_url: Optional[str] = None # Used for endpoint-discovery diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index c5dae9f5..5f51bae8 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -14,19 +14,26 @@ from __future__ import annotations +import logging from typing import Optional +from fastapi.openapi.models import OAuth2 + from ..agents.callback_context import CallbackContext from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_credential import AuthCredentialTypes from .auth_schemes import AuthSchemeType +from .auth_schemes import ExtendedOAuth2 from .auth_tool import AuthConfig from .exchanger.base_credential_exchanger import BaseCredentialExchanger from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry +from .oauth2_discovery import OAuth2DiscoveryManager from .refresher.base_credential_refresher import BaseCredentialRefresher from .refresher.credential_refresher_registry import CredentialRefresherRegistry +logger = logging.getLogger("google_adk." + __name__) + @experimental class CredentialManager: @@ -74,6 +81,7 @@ class CredentialManager: self._auth_config = auth_config self._exchanger_registry = CredentialExchangerRegistry() self._refresher_registry = CredentialRefresherRegistry() + self._discovery_manager = OAuth2DiscoveryManager() # Register default exchangers and refreshers # TODO: support service account credential exchanger @@ -247,7 +255,14 @@ class CredentialManager: "auth_config.raw_credential.oauth2 required for credential type " f"{raw_credential.auth_type}" ) - # Additional validation can be added here + + if self._missing_oauth_info() and not await self._populate_auth_scheme(): + raise ValueError( + "OAuth scheme info is missing, and auto-discovery has failed to fill" + " them in." + ) + + # Additional validation can be added here async def _save_credential( self, callback_context: CallbackContext, credential: AuthCredential @@ -259,3 +274,57 @@ class CredentialManager: credential_service = callback_context._invocation_context.credential_service if credential_service: await callback_context.save_credential(self._auth_config) + + async def _populate_auth_scheme(self) -> bool: + """Auto-discover server metadata and populate missing auth scheme info. + + Returns: + True if auto-discovery was successful, False otherwise. + """ + auth_scheme = self._auth_config.auth_scheme + if ( + not isinstance(auth_scheme, ExtendedOAuth2) + or not auth_scheme.issuer_url + ): + logger.warning("No issuer_url was provided for auto-discovery.") + return False + + metadata = await self._discovery_manager.discover_auth_server_metadata( + auth_scheme.issuer_url + ) + if not metadata: + logger.warning("Auto-discovery has failed to populate OAuth scheme info.") + return False + + flows = auth_scheme.flows + + if flows.implicit and not flows.implicit.authorizationUrl: + flows.implicit.authorizationUrl = metadata.authorization_endpoint + if flows.password and not flows.password.tokenUrl: + flows.password.tokenUrl = metadata.token_endpoint + if flows.clientCredentials and not flows.clientCredentials.tokenUrl: + flows.clientCredentials.tokenUrl = metadata.token_endpoint + if flows.authorizationCode and not flows.authorizationCode.authorizationUrl: + flows.authorizationCode.authorizationUrl = metadata.authorization_endpoint + if flows.authorizationCode and not flows.authorizationCode.tokenUrl: + flows.authorizationCode.tokenUrl = metadata.token_endpoint + return True + + def _missing_oauth_info(self) -> bool: + """Checks if we are missing auth/token URLs needed for OAuth.""" + auth_scheme = self._auth_config.auth_scheme + if isinstance(auth_scheme, OAuth2): + flows = auth_scheme.flows + return ( + flows.implicit + and not flows.implicit.authorizationUrl + or flows.password + and not flows.password.tokenUrl + or flows.clientCredentials + and not flows.clientCredentials.tokenUrl + or flows.authorizationCode + and not flows.authorizationCode.authorizationUrl + or flows.authorizationCode + and not flows.authorizationCode.tokenUrl + ) + return False diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index fd978604..8ecf6946 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -16,6 +16,10 @@ from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlowImplicit +from fastapi.openapi.models import OAuthFlows from google.adk.auth.auth_credential import AuthCredential from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.auth.auth_credential import OAuth2Auth @@ -23,8 +27,10 @@ from google.adk.auth.auth_credential import ServiceAccount from google.adk.auth.auth_credential import ServiceAccountCredential from google.adk.auth.auth_schemes import AuthScheme from google.adk.auth.auth_schemes import AuthSchemeType +from google.adk.auth.auth_schemes import ExtendedOAuth2 from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.credential_manager import CredentialManager +from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata import pytest @@ -390,6 +396,28 @@ class TestCredentialManager: with pytest.raises(ValueError, match="oauth2 required for credential type"): await manager._validate_credential() + @pytest.mark.asyncio + async def test_validate_credential_oauth2_missing_scheme_info( + self, extended_oauth2_scheme + ): + """Test _validate_credential with OAuth2 missing scheme info.""" + mock_raw_credential = Mock(spec=AuthCredential) + mock_raw_credential.auth_type = AuthCredentialTypes.OAUTH2 + mock_raw_credential.oauth2 = Mock(spec=OAuth2Auth) + + auth_config = Mock(spec=AuthConfig) + auth_config.raw_auth_credential = mock_raw_credential + auth_config.auth_scheme = extended_oauth2_scheme + + manager = CredentialManager(auth_config) + + with patch.object( + manager, + "_populate_auth_scheme", + return_value=False, + ) and pytest.raises(ValueError, match="OAuth scheme info is missing"): + await manager._validate_credential() + @pytest.mark.asyncio async def test_exchange_credentials_service_account(self): """Test _exchange_credential with service account credential.""" @@ -445,6 +473,95 @@ class TestCredentialManager: assert result == mock_credential assert was_exchanged is False + @pytest.fixture + def auth_server_metadata(self): + """Create AuthorizationServerMetadata object.""" + return AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + scopes_supported=["read", "write"], + ) + + @pytest.fixture + def extended_oauth2_scheme(self): + """Create ExtendedOAuth2 object with empty endpoints.""" + return ExtendedOAuth2( + issuer_url="https://auth.example.com", + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="", + tokenUrl="", + ) + ), + ) + + @pytest.fixture + def implicit_oauth2_scheme(self): + """Create OAuth2 object with implicit flow.""" + return OAuth2( + flows=OAuthFlows( + implicit=OAuthFlowImplicit( + authorizationUrl="https://auth.example.com/authorize" + ) + ) + ) + + @pytest.mark.asyncio + async def test_populate_auth_scheme_success( + self, auth_server_metadata, extended_oauth2_scheme + ): + """Test _populate_auth_scheme successfully populates missing info.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = extended_oauth2_scheme + + manager = CredentialManager(auth_config) + with patch.object( + manager._discovery_manager, + "discover_auth_server_metadata", + return_value=auth_server_metadata, + ): + assert await manager._populate_auth_scheme() + + assert ( + manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl + == "https://auth.example.com/authorize" + ) + assert ( + manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl + == "https://auth.example.com/token" + ) + + @pytest.mark.asyncio + async def test_populate_auth_scheme_fail(self, extended_oauth2_scheme): + """Test _populate_auth_scheme when auto-discovery fails.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = extended_oauth2_scheme + + manager = CredentialManager(auth_config) + with patch.object( + manager._discovery_manager, + "discover_auth_server_metadata", + return_value=None, + ): + assert not await manager._populate_auth_scheme() + + assert ( + not manager._auth_config.auth_scheme.flows.authorizationCode.authorizationUrl + ) + assert not manager._auth_config.auth_scheme.flows.authorizationCode.tokenUrl + + @pytest.mark.asyncio + async def test_populate_auth_scheme_noop(self, implicit_oauth2_scheme): + """Test _populate_auth_scheme when auth scheme info not missing.""" + auth_config = Mock(spec=AuthConfig) + auth_config.auth_scheme = implicit_oauth2_scheme + + manager = CredentialManager(auth_config) + assert not await manager._populate_auth_scheme() # no-op + + assert manager._auth_config.auth_scheme == implicit_oauth2_scheme + @pytest.fixture def oauth2_auth_scheme():