From 69997cd5ef44ee881a974bb36dc100e17ed6de2e Mon Sep 17 00:00:00 2001 From: Marlin Ranasinghe <77016115+MarlzRana@users.noreply.github.com> Date: Fri, 12 Dec 2025 10:57:31 -0800 Subject: [PATCH] fix: oauth refresh not triggered on token expiry Merge https://github.com/google/adk-python/pull/3767 Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3767 from MarlzRana:marlzrana/fix-oauth-refresh-not-triggered-on-token-expiry 2dae3917de6e8d3fa5317857d06305d47c0773b0 PiperOrigin-RevId: 843756363 --- src/google/adk/auth/auth_handler.py | 3 +- src/google/adk/auth/credential_manager.py | 17 ++--- .../exchanger/base_credential_exchanger.py | 14 ++++- .../exchanger/oauth2_credential_exchanger.py | 38 +++++++----- .../test_oauth2_credential_exchanger.py | 62 +++++++++---------- 5 files changed, 75 insertions(+), 59 deletions(-) diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 07515ab2..d472bff1 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -48,9 +48,10 @@ class AuthHandler: self, ) -> AuthCredential: exchanger = OAuth2CredentialExchanger() - return await exchanger.exchange( + exchange_result = await exchanger.exchange( self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme ) + return exchange_result.credential async def parse_and_store_auth_response(self, state: State) -> None: diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index c022ab69..2497c7b6 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -29,6 +29,7 @@ from .auth_schemes import ExtendedOAuth2 from .auth_schemes import OpenIdConnectWithConfig from .auth_tool import AuthConfig from .exchanger.base_credential_exchanger import BaseCredentialExchanger +from .exchanger.base_credential_exchanger import ExchangeResult from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry from .oauth2_discovery import OAuth2DiscoveryManager from .refresher.credential_refresher_registry import CredentialRefresherRegistry @@ -214,15 +215,17 @@ class CredentialManager: return credential, False if isinstance(exchanger, ServiceAccountCredentialExchanger): - exchanged_credential = exchanger.exchange_credential( - self._auth_config.auth_scheme, credential - ) - else: - exchanged_credential = await exchanger.exchange( - credential, self._auth_config.auth_scheme + return ( + exchanger.exchange_credential( + self._auth_config.auth_scheme, credential + ), + True, ) - return exchanged_credential, True + exchange_result = await exchanger.exchange( + credential, self._auth_config.auth_scheme + ) + return exchange_result.credential, exchange_result.was_exchanged async def _refresh_credential( self, credential: AuthCredential diff --git a/src/google/adk/auth/exchanger/base_credential_exchanger.py b/src/google/adk/auth/exchanger/base_credential_exchanger.py index 31106b55..a9d79aed 100644 --- a/src/google/adk/auth/exchanger/base_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/base_credential_exchanger.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +from typing import NamedTuple from typing import Optional from ...utils.feature_decorator import experimental @@ -28,6 +29,11 @@ class CredentialExchangeError(Exception): """Base exception for credential exchange errors.""" +class ExchangeResult(NamedTuple): + credential: AuthCredential + was_exchanged: bool + + @experimental class BaseCredentialExchanger(abc.ABC): """Base interface for credential exchangers. @@ -41,15 +47,17 @@ class BaseCredentialExchanger(abc.ABC): self, auth_credential: AuthCredential, auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: + ) -> ExchangeResult: """Exchange credential if needed. Args: auth_credential: The credential to exchange. - auth_scheme: The authentication scheme (optional, some exchangers don't need it). + auth_scheme: The authentication scheme (optional, some exchangers don't + need it). Returns: - The exchanged credential. + An ExchangeResult object containing the exchanged credential and a + boolean indicating whether the credential was exchanged. Raises: CredentialExchangeError: If credential exchange fails. diff --git a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py index 2adf0a81..0744e523 100644 --- a/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py +++ b/src/google/adk/auth/exchanger/oauth2_credential_exchanger.py @@ -31,6 +31,7 @@ from typing_extensions import override from .base_credential_exchanger import BaseCredentialExchanger from .base_credential_exchanger import CredentialExchangeError +from .base_credential_exchanger import ExchangeResult try: from authlib.integrations.requests_client import OAuth2Session @@ -51,7 +52,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): self, auth_credential: AuthCredential, auth_scheme: Optional[AuthScheme] = None, - ) -> AuthCredential: + ) -> ExchangeResult: """Exchange OAuth2 credential from authorization response. if credential exchange failed, the original credential will be returned. @@ -61,7 +62,8 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): auth_scheme: The OAuth2 authentication scheme. Returns: - The exchanged credential with access token. + An ExchangeResult object containing the exchanged credential and a + boolean indicating whether the credential was exchanged. Raises: CredentialExchangeError: If auth_scheme is missing. @@ -79,10 +81,10 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): logger.warning( "authlib is not available, skipping OAuth2 credential exchange." ) - return auth_credential + return ExchangeResult(auth_credential, False) if auth_credential.oauth2 and auth_credential.oauth2.access_token: - return auth_credential + return ExchangeResult(auth_credential, False) # Determine grant type from auth_scheme grant_type = self._determine_grant_type(auth_scheme) @@ -97,7 +99,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): ) else: logger.warning("Unsupported OAuth2 grant type: %s", grant_type) - return auth_credential + return ExchangeResult(auth_credential, False) def _determine_grant_type( self, auth_scheme: AuthScheme @@ -129,7 +131,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): self, auth_credential: AuthCredential, auth_scheme: AuthScheme, - ) -> AuthCredential: + ) -> ExchangeResult: """Exchange client credentials for access token. Args: @@ -137,14 +139,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): auth_scheme: The OAuth2 authentication scheme. Returns: - The credential with access token. + An ExchangeResult object containing the exchanged credential and a + boolean indicating whether the credential was exchanged. """ client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) if not client: logger.warning( "Could not create OAuth2 session for client credentials exchange" ) - return auth_credential + return ExchangeResult(auth_credential, False) try: tokens = client.fetch_token( @@ -155,13 +158,13 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): logger.debug("Successfully exchanged client credentials for access token") except Exception as e: logger.error("Failed to exchange client credentials: %s", e) - return auth_credential + return ExchangeResult(auth_credential, False) - return auth_credential + return ExchangeResult(auth_credential, True) def _normalize_auth_uri(self, auth_uri: str | None) -> str | None: - # Authlib currently used a simplified token check by simply scanning hash existence, - # yet itself might sometimes add extraneous hashes. + # Authlib currently used a simplified token check by simply scanning hash + # existence, yet itself might sometimes add extraneous hashes. # Drop trailing empty hash if seen. if auth_uri and auth_uri.endswith("#"): return auth_uri[:-1] @@ -171,7 +174,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): self, auth_credential: AuthCredential, auth_scheme: AuthScheme, - ) -> AuthCredential: + ) -> ExchangeResult: """Exchange authorization code for access token. Args: @@ -179,14 +182,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): auth_scheme: The OAuth2 authentication scheme. Returns: - The credential with access token. + An ExchangeResult object containing the exchanged credential and a + boolean indicating whether the credential was exchanged. """ client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential) if not client: logger.warning( "Could not create OAuth2 session for authorization code exchange" ) - return auth_credential + return ExchangeResult(auth_credential, False) try: tokens = client.fetch_token( @@ -202,6 +206,6 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger): logger.debug("Successfully exchanged authorization code for access token") except Exception as e: logger.error("Failed to exchange authorization code: %s", e) - return auth_credential + return ExchangeResult(auth_credential, False) - return auth_credential + return ExchangeResult(auth_credential, True) diff --git a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py index 84191b87..6762710c 100644 --- a/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py +++ b/tests/unittests/auth/exchanger/test_oauth2_credential_exchanger.py @@ -33,7 +33,6 @@ import pytest class TestOAuth2CredentialExchanger: """Test suite for OAuth2CredentialExchanger.""" - @pytest.mark.asyncio async def test_exchange_with_existing_token(self): """Test exchange method when access token already exists.""" scheme = OpenIdConnectWithConfig( @@ -55,14 +54,14 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Should return the same credential since access token already exists - assert result == credential - assert result.oauth2.access_token == "existing_token" + assert exchange_result.credential == credential + assert exchange_result.credential.oauth2.access_token == "existing_token" + assert not exchange_result.was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_success(self, mock_oauth2_session): """Test successful token exchange.""" # Setup mock @@ -96,14 +95,16 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Verify token exchange was successful - assert result.oauth2.access_token == "new_access_token" - assert result.oauth2.refresh_token == "new_refresh_token" + assert exchange_result.credential.oauth2.access_token == "new_access_token" + assert ( + exchange_result.credential.oauth2.refresh_token == "new_refresh_token" + ) + assert exchange_result.was_exchanged mock_client.fetch_token.assert_called_once() - @pytest.mark.asyncio async def test_exchange_missing_auth_scheme(self): """Test exchange with missing auth_scheme raises ValueError.""" credential = AuthCredential( @@ -122,7 +123,6 @@ class TestOAuth2CredentialExchanger: assert "auth_scheme is required" in str(e) @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_no_session(self, mock_oauth2_session): """Test exchange when OAuth2Session cannot be created.""" # Mock to return None for create_oauth2_session @@ -146,14 +146,14 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Should return original credential when session creation fails - assert result == credential - assert result.oauth2.access_token is None + assert exchange_result.credential == credential + assert exchange_result.credential.oauth2.access_token is None + assert not exchange_result.was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_fetch_token_failure(self, mock_oauth2_session): """Test exchange when fetch_token fails.""" # Setup mock to raise exception during fetch_token @@ -181,14 +181,14 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Should return original credential when fetch_token fails - assert result == credential - assert result.oauth2.access_token is None + assert exchange_result.credential == credential + assert exchange_result.credential.oauth2.access_token is None + assert not exchange_result.was_exchanged mock_client.fetch_token.assert_called_once() - @pytest.mark.asyncio async def test_exchange_authlib_not_available(self): """Test exchange when authlib is not available.""" scheme = OpenIdConnectWithConfig( @@ -217,14 +217,14 @@ class TestOAuth2CredentialExchanger: "google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE", False, ): - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Should return original credential when authlib is not available - assert result == credential - assert result.oauth2.access_token is None + assert exchange_result.credential == credential + assert exchange_result.credential.oauth2.access_token is None + assert not exchange_result.was_exchanged @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_client_credentials_success(self, mock_oauth2_session): """Test successful client credentials exchange.""" # Setup mock @@ -255,17 +255,19 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Verify client credentials exchange was successful - assert result.oauth2.access_token == "client_access_token" + assert ( + exchange_result.credential.oauth2.access_token == "client_access_token" + ) + assert exchange_result.was_exchanged mock_client.fetch_token.assert_called_once_with( "https://example.com/token", grant_type="client_credentials", ) @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_client_credentials_failure(self, mock_oauth2_session): """Test client credentials exchange failure.""" # Setup mock to raise exception during fetch_token @@ -292,15 +294,15 @@ class TestOAuth2CredentialExchanger: ) exchanger = OAuth2CredentialExchanger() - result = await exchanger.exchange(credential, scheme) + exchange_result = await exchanger.exchange(credential, scheme) # Should return original credential when client credentials exchange fails - assert result == credential - assert result.oauth2.access_token is None + assert exchange_result.credential == credential + assert exchange_result.credential.oauth2.access_token is None + assert not exchange_result.was_exchanged mock_client.fetch_token.assert_called_once() @patch("google.adk.auth.oauth2_credential_util.OAuth2Session") - @pytest.mark.asyncio async def test_exchange_normalize_uri(self, mock_oauth2_session): """Test exchange method normalizes auth_response_uri.""" mock_client = Mock() @@ -344,7 +346,6 @@ class TestOAuth2CredentialExchanger: client_id="test_client_id", ) - @pytest.mark.asyncio async def test_determine_grant_type_client_credentials(self): """Test grant type determination for client credentials.""" flows = OAuthFlows( @@ -361,7 +362,6 @@ class TestOAuth2CredentialExchanger: assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS - @pytest.mark.asyncio async def test_determine_grant_type_openid_connect(self): """Test grant type determination for OpenID Connect (defaults to auth code).""" scheme = OpenIdConnectWithConfig(