You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: oauth refresh not triggered on token expiry
Merge https://github.com/google/adk-python/pull/3767 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3767 from MarlzRana:marlzrana/fix-oauth-refresh-not-triggered-on-token-expiry 2dae3917de6e8d3fa5317857d06305d47c0773b0 PiperOrigin-RevId: 843756363
This commit is contained in:
committed by
Copybara-Service
parent
5c4bae7ff2
commit
69997cd5ef
@@ -48,9 +48,10 @@ class AuthHandler:
|
||||
self,
|
||||
) -> AuthCredential:
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
return await exchanger.exchange(
|
||||
exchange_result = await exchanger.exchange(
|
||||
self.auth_config.exchanged_auth_credential, self.auth_config.auth_scheme
|
||||
)
|
||||
return exchange_result.credential
|
||||
|
||||
async def parse_and_store_auth_response(self, state: State) -> None:
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from .auth_schemes import ExtendedOAuth2
|
||||
from .auth_schemes import OpenIdConnectWithConfig
|
||||
from .auth_tool import AuthConfig
|
||||
from .exchanger.base_credential_exchanger import BaseCredentialExchanger
|
||||
from .exchanger.base_credential_exchanger import ExchangeResult
|
||||
from .exchanger.credential_exchanger_registry import CredentialExchangerRegistry
|
||||
from .oauth2_discovery import OAuth2DiscoveryManager
|
||||
from .refresher.credential_refresher_registry import CredentialRefresherRegistry
|
||||
@@ -214,15 +215,17 @@ class CredentialManager:
|
||||
return credential, False
|
||||
|
||||
if isinstance(exchanger, ServiceAccountCredentialExchanger):
|
||||
exchanged_credential = exchanger.exchange_credential(
|
||||
self._auth_config.auth_scheme, credential
|
||||
)
|
||||
else:
|
||||
exchanged_credential = await exchanger.exchange(
|
||||
credential, self._auth_config.auth_scheme
|
||||
return (
|
||||
exchanger.exchange_credential(
|
||||
self._auth_config.auth_scheme, credential
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
return exchanged_credential, True
|
||||
exchange_result = await exchanger.exchange(
|
||||
credential, self._auth_config.auth_scheme
|
||||
)
|
||||
return exchange_result.credential, exchange_result.was_exchanged
|
||||
|
||||
async def _refresh_credential(
|
||||
self, credential: AuthCredential
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
|
||||
from ...utils.feature_decorator import experimental
|
||||
@@ -28,6 +29,11 @@ class CredentialExchangeError(Exception):
|
||||
"""Base exception for credential exchange errors."""
|
||||
|
||||
|
||||
class ExchangeResult(NamedTuple):
|
||||
credential: AuthCredential
|
||||
was_exchanged: bool
|
||||
|
||||
|
||||
@experimental
|
||||
class BaseCredentialExchanger(abc.ABC):
|
||||
"""Base interface for credential exchangers.
|
||||
@@ -41,15 +47,17 @@ class BaseCredentialExchanger(abc.ABC):
|
||||
self,
|
||||
auth_credential: AuthCredential,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
) -> AuthCredential:
|
||||
) -> ExchangeResult:
|
||||
"""Exchange credential if needed.
|
||||
|
||||
Args:
|
||||
auth_credential: The credential to exchange.
|
||||
auth_scheme: The authentication scheme (optional, some exchangers don't need it).
|
||||
auth_scheme: The authentication scheme (optional, some exchangers don't
|
||||
need it).
|
||||
|
||||
Returns:
|
||||
The exchanged credential.
|
||||
An ExchangeResult object containing the exchanged credential and a
|
||||
boolean indicating whether the credential was exchanged.
|
||||
|
||||
Raises:
|
||||
CredentialExchangeError: If credential exchange fails.
|
||||
|
||||
@@ -31,6 +31,7 @@ from typing_extensions import override
|
||||
|
||||
from .base_credential_exchanger import BaseCredentialExchanger
|
||||
from .base_credential_exchanger import CredentialExchangeError
|
||||
from .base_credential_exchanger import ExchangeResult
|
||||
|
||||
try:
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
@@ -51,7 +52,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
self,
|
||||
auth_credential: AuthCredential,
|
||||
auth_scheme: Optional[AuthScheme] = None,
|
||||
) -> AuthCredential:
|
||||
) -> ExchangeResult:
|
||||
"""Exchange OAuth2 credential from authorization response.
|
||||
|
||||
if credential exchange failed, the original credential will be returned.
|
||||
@@ -61,7 +62,8 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
auth_scheme: The OAuth2 authentication scheme.
|
||||
|
||||
Returns:
|
||||
The exchanged credential with access token.
|
||||
An ExchangeResult object containing the exchanged credential and a
|
||||
boolean indicating whether the credential was exchanged.
|
||||
|
||||
Raises:
|
||||
CredentialExchangeError: If auth_scheme is missing.
|
||||
@@ -79,10 +81,10 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
logger.warning(
|
||||
"authlib is not available, skipping OAuth2 credential exchange."
|
||||
)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
if auth_credential.oauth2 and auth_credential.oauth2.access_token:
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
# Determine grant type from auth_scheme
|
||||
grant_type = self._determine_grant_type(auth_scheme)
|
||||
@@ -97,7 +99,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
)
|
||||
else:
|
||||
logger.warning("Unsupported OAuth2 grant type: %s", grant_type)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
def _determine_grant_type(
|
||||
self, auth_scheme: AuthScheme
|
||||
@@ -129,7 +131,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
self,
|
||||
auth_credential: AuthCredential,
|
||||
auth_scheme: AuthScheme,
|
||||
) -> AuthCredential:
|
||||
) -> ExchangeResult:
|
||||
"""Exchange client credentials for access token.
|
||||
|
||||
Args:
|
||||
@@ -137,14 +139,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
auth_scheme: The OAuth2 authentication scheme.
|
||||
|
||||
Returns:
|
||||
The credential with access token.
|
||||
An ExchangeResult object containing the exchanged credential and a
|
||||
boolean indicating whether the credential was exchanged.
|
||||
"""
|
||||
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
|
||||
if not client:
|
||||
logger.warning(
|
||||
"Could not create OAuth2 session for client credentials exchange"
|
||||
)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
try:
|
||||
tokens = client.fetch_token(
|
||||
@@ -155,13 +158,13 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
logger.debug("Successfully exchanged client credentials for access token")
|
||||
except Exception as e:
|
||||
logger.error("Failed to exchange client credentials: %s", e)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, True)
|
||||
|
||||
def _normalize_auth_uri(self, auth_uri: str | None) -> str | None:
|
||||
# Authlib currently used a simplified token check by simply scanning hash existence,
|
||||
# yet itself might sometimes add extraneous hashes.
|
||||
# Authlib currently used a simplified token check by simply scanning hash
|
||||
# existence, yet itself might sometimes add extraneous hashes.
|
||||
# Drop trailing empty hash if seen.
|
||||
if auth_uri and auth_uri.endswith("#"):
|
||||
return auth_uri[:-1]
|
||||
@@ -171,7 +174,7 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
self,
|
||||
auth_credential: AuthCredential,
|
||||
auth_scheme: AuthScheme,
|
||||
) -> AuthCredential:
|
||||
) -> ExchangeResult:
|
||||
"""Exchange authorization code for access token.
|
||||
|
||||
Args:
|
||||
@@ -179,14 +182,15 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
auth_scheme: The OAuth2 authentication scheme.
|
||||
|
||||
Returns:
|
||||
The credential with access token.
|
||||
An ExchangeResult object containing the exchanged credential and a
|
||||
boolean indicating whether the credential was exchanged.
|
||||
"""
|
||||
client, token_endpoint = create_oauth2_session(auth_scheme, auth_credential)
|
||||
if not client:
|
||||
logger.warning(
|
||||
"Could not create OAuth2 session for authorization code exchange"
|
||||
)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
try:
|
||||
tokens = client.fetch_token(
|
||||
@@ -202,6 +206,6 @@ class OAuth2CredentialExchanger(BaseCredentialExchanger):
|
||||
logger.debug("Successfully exchanged authorization code for access token")
|
||||
except Exception as e:
|
||||
logger.error("Failed to exchange authorization code: %s", e)
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, False)
|
||||
|
||||
return auth_credential
|
||||
return ExchangeResult(auth_credential, True)
|
||||
|
||||
@@ -33,7 +33,6 @@ import pytest
|
||||
class TestOAuth2CredentialExchanger:
|
||||
"""Test suite for OAuth2CredentialExchanger."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_with_existing_token(self):
|
||||
"""Test exchange method when access token already exists."""
|
||||
scheme = OpenIdConnectWithConfig(
|
||||
@@ -55,14 +54,14 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Should return the same credential since access token already exists
|
||||
assert result == credential
|
||||
assert result.oauth2.access_token == "existing_token"
|
||||
assert exchange_result.credential == credential
|
||||
assert exchange_result.credential.oauth2.access_token == "existing_token"
|
||||
assert not exchange_result.was_exchanged
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_success(self, mock_oauth2_session):
|
||||
"""Test successful token exchange."""
|
||||
# Setup mock
|
||||
@@ -96,14 +95,16 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Verify token exchange was successful
|
||||
assert result.oauth2.access_token == "new_access_token"
|
||||
assert result.oauth2.refresh_token == "new_refresh_token"
|
||||
assert exchange_result.credential.oauth2.access_token == "new_access_token"
|
||||
assert (
|
||||
exchange_result.credential.oauth2.refresh_token == "new_refresh_token"
|
||||
)
|
||||
assert exchange_result.was_exchanged
|
||||
mock_client.fetch_token.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_missing_auth_scheme(self):
|
||||
"""Test exchange with missing auth_scheme raises ValueError."""
|
||||
credential = AuthCredential(
|
||||
@@ -122,7 +123,6 @@ class TestOAuth2CredentialExchanger:
|
||||
assert "auth_scheme is required" in str(e)
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_no_session(self, mock_oauth2_session):
|
||||
"""Test exchange when OAuth2Session cannot be created."""
|
||||
# Mock to return None for create_oauth2_session
|
||||
@@ -146,14 +146,14 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Should return original credential when session creation fails
|
||||
assert result == credential
|
||||
assert result.oauth2.access_token is None
|
||||
assert exchange_result.credential == credential
|
||||
assert exchange_result.credential.oauth2.access_token is None
|
||||
assert not exchange_result.was_exchanged
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_fetch_token_failure(self, mock_oauth2_session):
|
||||
"""Test exchange when fetch_token fails."""
|
||||
# Setup mock to raise exception during fetch_token
|
||||
@@ -181,14 +181,14 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Should return original credential when fetch_token fails
|
||||
assert result == credential
|
||||
assert result.oauth2.access_token is None
|
||||
assert exchange_result.credential == credential
|
||||
assert exchange_result.credential.oauth2.access_token is None
|
||||
assert not exchange_result.was_exchanged
|
||||
mock_client.fetch_token.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_authlib_not_available(self):
|
||||
"""Test exchange when authlib is not available."""
|
||||
scheme = OpenIdConnectWithConfig(
|
||||
@@ -217,14 +217,14 @@ class TestOAuth2CredentialExchanger:
|
||||
"google.adk.auth.exchanger.oauth2_credential_exchanger.AUTHLIB_AVAILABLE",
|
||||
False,
|
||||
):
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Should return original credential when authlib is not available
|
||||
assert result == credential
|
||||
assert result.oauth2.access_token is None
|
||||
assert exchange_result.credential == credential
|
||||
assert exchange_result.credential.oauth2.access_token is None
|
||||
assert not exchange_result.was_exchanged
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_client_credentials_success(self, mock_oauth2_session):
|
||||
"""Test successful client credentials exchange."""
|
||||
# Setup mock
|
||||
@@ -255,17 +255,19 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Verify client credentials exchange was successful
|
||||
assert result.oauth2.access_token == "client_access_token"
|
||||
assert (
|
||||
exchange_result.credential.oauth2.access_token == "client_access_token"
|
||||
)
|
||||
assert exchange_result.was_exchanged
|
||||
mock_client.fetch_token.assert_called_once_with(
|
||||
"https://example.com/token",
|
||||
grant_type="client_credentials",
|
||||
)
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_client_credentials_failure(self, mock_oauth2_session):
|
||||
"""Test client credentials exchange failure."""
|
||||
# Setup mock to raise exception during fetch_token
|
||||
@@ -292,15 +294,15 @@ class TestOAuth2CredentialExchanger:
|
||||
)
|
||||
|
||||
exchanger = OAuth2CredentialExchanger()
|
||||
result = await exchanger.exchange(credential, scheme)
|
||||
exchange_result = await exchanger.exchange(credential, scheme)
|
||||
|
||||
# Should return original credential when client credentials exchange fails
|
||||
assert result == credential
|
||||
assert result.oauth2.access_token is None
|
||||
assert exchange_result.credential == credential
|
||||
assert exchange_result.credential.oauth2.access_token is None
|
||||
assert not exchange_result.was_exchanged
|
||||
mock_client.fetch_token.assert_called_once()
|
||||
|
||||
@patch("google.adk.auth.oauth2_credential_util.OAuth2Session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_normalize_uri(self, mock_oauth2_session):
|
||||
"""Test exchange method normalizes auth_response_uri."""
|
||||
mock_client = Mock()
|
||||
@@ -344,7 +346,6 @@ class TestOAuth2CredentialExchanger:
|
||||
client_id="test_client_id",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_determine_grant_type_client_credentials(self):
|
||||
"""Test grant type determination for client credentials."""
|
||||
flows = OAuthFlows(
|
||||
@@ -361,7 +362,6 @@ class TestOAuth2CredentialExchanger:
|
||||
|
||||
assert grant_type == OAuthGrantType.CLIENT_CREDENTIALS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_determine_grant_type_openid_connect(self):
|
||||
"""Test grant type determination for OpenID Connect (defaults to auth code)."""
|
||||
scheme = OpenIdConnectWithConfig(
|
||||
|
||||
Reference in New Issue
Block a user