fix: oauth refresh not triggered on token expiry

Merge https://github.com/google/adk-python/pull/3767

Co-authored-by: Xuan Yang <xygoogle@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3767 from MarlzRana:marlzrana/fix-oauth-refresh-not-triggered-on-token-expiry 2dae3917de6e8d3fa5317857d06305d47c0773b0
PiperOrigin-RevId: 843756363
This commit is contained in:
Marlin Ranasinghe
2025-12-12 10:57:31 -08:00
committed by Copybara-Service
parent 5c4bae7ff2
commit 69997cd5ef
5 changed files with 75 additions and 59 deletions
+2 -1
View File
@@ -48,9 +48,10 @@ class AuthHandler:
self,
) -> AuthCredential:
exchanger = OAuth2CredentialExchanger()
return await exchanger.exchange(
exchange_result = await exchanger.exchange(
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
)
return exchange_result.credential
async def parse_and_store_auth_response(self, state: State) -> None:
+10 -7
View File
@@ -29,6 +29,7 @@ from .auth_schemes import ExtendedOAuth2
from .auth_schemes import OpenIdConnectWithConfig
from .auth_tool import AuthConfig
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
from .exchanger.base_credential_exchanger import ExchangeResult
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
from .oauth2_discovery import OAuth2DiscoveryManager
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
@@ -214,15 +215,17 @@ class CredentialManager:
return credential, False
if isinstance(exchanger, ServiceAccountCredentialExchanger):
exchanged_credential = exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
)
else:
exchanged_credential = await exchanger.exchange(
credential, self._auth_config.auth_scheme
return (
exchanger.exchange_credential(
self._auth_config.auth_scheme, credential
),
True,
)
return exchanged_credential, True
exchange_result = await exchanger.exchange(
credential, self._auth_config.auth_scheme
)
return exchange_result.credential, exchange_result.was_exchanged
async def _refresh_credential(
self, credential: AuthCredential
@@ -17,6 +17,7 @@
from __future__ import annotations
import abc
from typing import NamedTuple
from typing import Optional
from ...utils.feature_decorator import experimental
@@ -28,6 +29,11 @@ class CredentialExchangeError(Exception):
"""Base exception for credential exchange errors."""
class ExchangeResult(NamedTuple):
credential: AuthCredential
was_exchanged: bool
@experimental
class BaseCredentialExchanger(abc.ABC):
"""Base interface for credential exchangers.
@@ -41,15 +47,17 @@ class BaseCredentialExchanger(abc.ABC):
self,
auth_credential: AuthCredential,
auth_scheme: Optional[AuthScheme] = None,
) -> AuthCredential:
) -> ExchangeResult:
"""Exchange credential if needed.
Args:
auth_credential: The credential to exchange.
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
auth_scheme: The authentication scheme (optional, some exchangers don't
need it).
Returns:
The exchanged credential.
An ExchangeResult object containing the exchanged credential and a
boolean indicating whether the credential was exchanged.
Raises:
CredentialExchangeError: If credential exchange fails.
@@ -31,6 +31,7 @@ from typing_extensions import override
from .base_credential_exchanger import BaseCredentialExchanger
from .base_credential_exchanger import CredentialExchangeError
from .base_credential_exchanger import ExchangeResult
try:
from authlib.integrations.requests_client import OAuth2Session
@@ -51,7 +52,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
self,
auth_credential: AuthCredential,
auth_scheme: Optional[AuthScheme] = None,
) -> AuthCredential:
) -> ExchangeResult:
"""Exchange OAuth2 credential from authorization response.
if credential exchange failed, the original credential will be returned.
@@ -61,7 +62,8 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
auth_scheme: The OAuth2 authentication scheme.
Returns:
The exchanged credential with access token.
An ExchangeResult object containing the exchanged credential and a
boolean indicating whether the credential was exchanged.
Raises:
CredentialExchangeError: If auth_scheme is missing.
@@ -79,10 +81,10 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
logger.warning(
"authlib is not available, skipping OAuth2 credential exchange."
)
return auth_credential
return ExchangeResult(auth_credential, False)
if auth_credential.oauth2 and auth_credential.oauth2.access_token:
return auth_credential
return ExchangeResult(auth_credential, False)
# Determine grant type from auth_scheme
grant_type = self._determine_grant_type(auth_scheme)
@@ -97,7 +99,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
)
else:
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
return auth_credential
return ExchangeResult(auth_credential, False)
def _determine_grant_type(
self, auth_scheme: AuthScheme
@@ -129,7 +131,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
self,
auth_credential: AuthCredential,
auth_scheme: AuthScheme,
) -> AuthCredential:
) -> ExchangeResult:
"""Exchange client credentials for access token.
Args:
@@ -137,14 +139,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
auth_scheme: The OAuth2 authentication scheme.
Returns:
The credential with access token.
An ExchangeResult object containing the exchanged credential and a
boolean indicating whether the credential was exchanged.
"""
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
if not client:
logger.warning(
"Could not create OAuth2 session for client credentials exchange"
)
return auth_credential
return ExchangeResult(auth_credential, False)
try:
tokens = client.fetch_token(
@@ -155,13 +158,13 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
logger.debug("Successfully exchanged client credentials for access token")
except Exception as e:
logger.error("Failed to exchange client credentials: %s", e)
return auth_credential
return ExchangeResult(auth_credential, False)
return auth_credential
return ExchangeResult(auth_credential, True)
def _normalize_auth_uri(self, auth_uri: str | None) -> str | None:
# Authlib currently used a simplified token check by simply scanning hash existence,
# yet itself might sometimes add extraneous hashes.
# Authlib currently used a simplified token check by simply scanning hash
# existence, yet itself might sometimes add extraneous hashes.
# Drop trailing empty hash if seen.
if auth_uri and auth_uri.endswith("#"):
return auth_uri[:-1]
@@ -171,7 +174,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
self,
auth_credential: AuthCredential,
auth_scheme: AuthScheme,
) -> AuthCredential:
) -> ExchangeResult:
"""Exchange authorization code for access token.
Args:
@@ -179,14 +182,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
auth_scheme: The OAuth2 authentication scheme.
Returns:
The credential with access token.
An ExchangeResult object containing the exchanged credential and a
boolean indicating whether the credential was exchanged.
"""
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
if not client:
logger.warning(
"Could not create OAuth2 session for authorization code exchange"
)
return auth_credential
return ExchangeResult(auth_credential, False)
try:
tokens = client.fetch_token(
@@ -202,6 +206,6 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
logger.debug("Successfully exchanged authorization code for access token")
except Exception as e:
logger.error("Failed to exchange authorization code: %s", e)
return auth_credential
return ExchangeResult(auth_credential, False)
return auth_credential
return ExchangeResult(auth_credential, True)
@@ -33,7 +33,6 @@ import pytest
class TestOAuth2CredentialExchanger:
"""Test suite for OAuth2CredentialExchanger."""
@pytest.mark.asyncio
async def test_exchange_with_existing_token(self):
"""Test exchange method when access token already exists."""
scheme = OpenIdConnectWithConfig(
@@ -55,14 +54,14 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Should return the same credential since access token already exists
assert result == credential
assert result.oauth2.access_token == "existing_token"
assert exchange_result.credential == credential
assert exchange_result.credential.oauth2.access_token == "existing_token"
assert not exchange_result.was_exchanged
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_success(self, mock_oauth2_session):
"""Test successful token exchange."""
# Setup mock
@@ -96,14 +95,16 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Verify token exchange was successful
assert result.oauth2.access_token == "new_access_token"
assert result.oauth2.refresh_token == "new_refresh_token"
assert exchange_result.credential.oauth2.access_token == "new_access_token"
assert (
exchange_result.credential.oauth2.refresh_token == "new_refresh_token"
)
assert exchange_result.was_exchanged
mock_client.fetch_token.assert_called_once()
@pytest.mark.asyncio
async def test_exchange_missing_auth_scheme(self):
"""Test exchange with missing auth_scheme raises ValueError."""
credential = AuthCredential(
@@ -122,7 +123,6 @@ class TestOAuth2CredentialExchanger:
assert "auth_scheme is required" in str(e)
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_no_session(self, mock_oauth2_session):
"""Test exchange when OAuth2Session cannot be created."""
# Mock to return None for create_oauth2_session
@@ -146,14 +146,14 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Should return original credential when session creation fails
assert result == credential
assert result.oauth2.access_token is None
assert exchange_result.credential == credential
assert exchange_result.credential.oauth2.access_token is None
assert not exchange_result.was_exchanged
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
"""Test exchange when fetch_token fails."""
# Setup mock to raise exception during fetch_token
@@ -181,14 +181,14 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Should return original credential when fetch_token fails
assert result == credential
assert result.oauth2.access_token is None
assert exchange_result.credential == credential
assert exchange_result.credential.oauth2.access_token is None
assert not exchange_result.was_exchanged
mock_client.fetch_token.assert_called_once()
@pytest.mark.asyncio
async def test_exchange_authlib_not_available(self):
"""Test exchange when authlib is not available."""
scheme = OpenIdConnectWithConfig(
@@ -217,14 +217,14 @@ class TestOAuth2CredentialExchanger:
"google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE",
False,
):
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Should return original credential when authlib is not available
assert result == credential
assert result.oauth2.access_token is None
assert exchange_result.credential == credential
assert exchange_result.credential.oauth2.access_token is None
assert not exchange_result.was_exchanged
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
"""Test successful client credentials exchange."""
# Setup mock
@@ -255,17 +255,19 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Verify client credentials exchange was successful
assert result.oauth2.access_token == "client_access_token"
assert (
exchange_result.credential.oauth2.access_token == "client_access_token"
)
assert exchange_result.was_exchanged
mock_client.fetch_token.assert_called_once_with(
"https://example.com/token",
grant_type="client_credentials",
)
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
"""Test client credentials exchange failure."""
# Setup mock to raise exception during fetch_token
@@ -292,15 +294,15 @@ class TestOAuth2CredentialExchanger:
)
exchanger = OAuth2CredentialExchanger()
result = await exchanger.exchange(credential, scheme)
exchange_result = await exchanger.exchange(credential, scheme)
# Should return original credential when client credentials exchange fails
assert result == credential
assert result.oauth2.access_token is None
assert exchange_result.credential == credential
assert exchange_result.credential.oauth2.access_token is None
assert not exchange_result.was_exchanged
mock_client.fetch_token.assert_called_once()
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
@pytest.mark.asyncio
async def test_exchange_normalize_uri(self, mock_oauth2_session):
"""Test exchange method normalizes auth_response_uri."""
mock_client = Mock()
@@ -344,7 +346,6 @@ class TestOAuth2CredentialExchanger:
client_id="test_client_id",
)
@pytest.mark.asyncio
async def test_determine_grant_type_client_credentials(self):
"""Test grant type determination for client credentials."""
flows = OAuthFlows(
@@ -361,7 +362,6 @@ class TestOAuth2CredentialExchanger:
assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS
@pytest.mark.asyncio
async def test_determine_grant_type_openid_connect(self):
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
scheme = OpenIdConnectWithConfig(