You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix!: Make credential manager to accept tool_context instead of callback_context
This seems a breaking change, but actually credential manager is used internally only and also it won't work if some one call it using callback context Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 863534476
This commit is contained in:
committed by
Copybara-Service
parent
798d0053c8
commit
fe82f3cde8
@@ -19,8 +19,8 @@ from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
|
||||
from ..agents.callback_context import CallbackContext
|
||||
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
|
||||
from ..tools.tool_context import ToolContext
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .auth_credential import AuthCredential
|
||||
from .auth_credential import AuthCredentialTypes
|
||||
@@ -72,7 +72,7 @@ class CredentialManager:
|
||||
)
|
||||
|
||||
# Load and prepare credential
|
||||
credential = await manager.load_auth_credential(callback_context)
|
||||
credential = await manager.load_auth_credential(tool_context)
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -124,11 +124,11 @@ class CredentialManager:
|
||||
"""
|
||||
self._exchanger_registry.register(credential_type, exchanger_instance)
|
||||
|
||||
async def request_credential(self, callback_context: CallbackContext) -> None:
|
||||
callback_context.request_credential(self._auth_config)
|
||||
async def request_credential(self, tool_context: ToolContext) -> None:
|
||||
tool_context.request_credential(self._auth_config)
|
||||
|
||||
async def get_auth_credential(
|
||||
self, callback_context: CallbackContext
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Load and prepare authentication credential through a structured workflow."""
|
||||
|
||||
@@ -140,14 +140,14 @@ class CredentialManager:
|
||||
return self._auth_config.raw_auth_credential
|
||||
|
||||
# Step 3: Try to load existing processed credential
|
||||
credential = await self._load_existing_credential(callback_context)
|
||||
credential = await self._load_existing_credential(tool_context)
|
||||
|
||||
# Step 4: If no existing credential, load from auth response
|
||||
# TODO instead of load from auth response, we can store auth response in
|
||||
# credential service.
|
||||
was_from_auth_response = False
|
||||
if not credential:
|
||||
credential = await self._load_from_auth_response(callback_context)
|
||||
credential = await self._load_from_auth_response(tool_context)
|
||||
was_from_auth_response = True
|
||||
|
||||
# Step 5: If still no credential available, check if client credentials
|
||||
@@ -169,38 +169,38 @@ class CredentialManager:
|
||||
|
||||
# Step 8: Save credential if it was modified
|
||||
if was_from_auth_response or was_exchanged or was_refreshed:
|
||||
await self._save_credential(callback_context, credential)
|
||||
await self._save_credential(tool_context, credential)
|
||||
|
||||
return credential
|
||||
|
||||
async def _load_existing_credential(
|
||||
self, callback_context: CallbackContext
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Load existing credential from credential service."""
|
||||
|
||||
# Try loading from credential service first
|
||||
credential = await self._load_from_credential_service(callback_context)
|
||||
credential = await self._load_from_credential_service(tool_context)
|
||||
if credential:
|
||||
return credential
|
||||
|
||||
return None
|
||||
|
||||
async def _load_from_credential_service(
|
||||
self, callback_context: CallbackContext
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Load credential from credential service if available."""
|
||||
credential_service = callback_context._invocation_context.credential_service
|
||||
credential_service = tool_context._invocation_context.credential_service
|
||||
if credential_service:
|
||||
# Note: This should be made async in a future refactor
|
||||
# For now, assuming synchronous operation
|
||||
return await callback_context.load_credential(self._auth_config)
|
||||
return await tool_context.load_credential(self._auth_config)
|
||||
return None
|
||||
|
||||
async def _load_from_auth_response(
|
||||
self, callback_context: CallbackContext
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Load credential from auth response in callback context."""
|
||||
return callback_context.get_auth_response(self._auth_config)
|
||||
"""Load credential from auth response in tool context."""
|
||||
return tool_context.get_auth_response(self._auth_config)
|
||||
|
||||
async def _exchange_credential(
|
||||
self, credential: AuthCredential
|
||||
@@ -290,15 +290,15 @@ class CredentialManager:
|
||||
# Additional validation can be added here
|
||||
|
||||
async def _save_credential(
|
||||
self, callback_context: CallbackContext, credential: AuthCredential
|
||||
self, tool_context: ToolContext, credential: AuthCredential
|
||||
) -> None:
|
||||
"""Save credential to credential service if available."""
|
||||
# Update the exchanged credential in config
|
||||
self._auth_config.exchanged_auth_credential = credential
|
||||
|
||||
credential_service = callback_context._invocation_context.credential_service
|
||||
credential_service = tool_context._invocation_context.credential_service
|
||||
if credential_service:
|
||||
await callback_context.save_credential(self._auth_config)
|
||||
await tool_context.save_credential(self._auth_config)
|
||||
|
||||
async def _populate_auth_scheme(self) -> bool:
|
||||
"""Auto-discover server metadata and populate missing auth scheme info.
|
||||
|
||||
@@ -49,13 +49,13 @@ class TestCredentialManager:
|
||||
async def test_request_credential(self):
|
||||
"""Test request_credential method."""
|
||||
auth_config = Mock(spec=AuthConfig)
|
||||
callback_context = Mock()
|
||||
callback_context.request_credential = Mock()
|
||||
tool_context = Mock()
|
||||
tool_context.request_credential = Mock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
await manager.request_credential(callback_context)
|
||||
await manager.request_credential(tool_context)
|
||||
|
||||
callback_context.request_credential.assert_called_once_with(auth_config)
|
||||
tool_context.request_credential.assert_called_once_with(auth_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_auth_credentials_success(self):
|
||||
@@ -69,7 +69,7 @@ class TestCredentialManager:
|
||||
mock_credential = Mock(spec=AuthCredential)
|
||||
mock_credential.auth_type = AuthCredentialTypes.API_KEY
|
||||
|
||||
callback_context = Mock()
|
||||
tool_context = Mock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
|
||||
@@ -86,17 +86,17 @@ class TestCredentialManager:
|
||||
)
|
||||
manager._save_credential = AsyncMock()
|
||||
|
||||
result = await manager.get_auth_credential(callback_context)
|
||||
result = await manager.get_auth_credential(tool_context)
|
||||
|
||||
# Verify all methods were called
|
||||
manager._validate_credential.assert_called_once()
|
||||
manager._is_credential_ready.assert_called_once()
|
||||
manager._load_existing_credential.assert_called_once_with(callback_context)
|
||||
manager._load_from_auth_response.assert_called_once_with(callback_context)
|
||||
manager._load_existing_credential.assert_called_once_with(tool_context)
|
||||
manager._load_from_auth_response.assert_called_once_with(tool_context)
|
||||
manager._exchange_credential.assert_called_once_with(mock_credential)
|
||||
manager._refresh_credential.assert_called_once_with(mock_credential)
|
||||
manager._save_credential.assert_called_once_with(
|
||||
callback_context, mock_credential
|
||||
tool_context, mock_credential
|
||||
)
|
||||
|
||||
assert result == mock_credential
|
||||
@@ -111,7 +111,7 @@ class TestCredentialManager:
|
||||
auth_config.auth_scheme = Mock()
|
||||
auth_config.auth_scheme.flows = None
|
||||
|
||||
callback_context = Mock()
|
||||
tool_context = Mock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
|
||||
@@ -121,13 +121,13 @@ class TestCredentialManager:
|
||||
manager._load_existing_credential = AsyncMock(return_value=None)
|
||||
manager._load_from_auth_response = AsyncMock(return_value=None)
|
||||
|
||||
result = await manager.get_auth_credential(callback_context)
|
||||
result = await manager.get_auth_credential(tool_context)
|
||||
|
||||
# Verify methods were called but no credential returned
|
||||
manager._validate_credential.assert_called_once()
|
||||
manager._is_credential_ready.assert_called_once()
|
||||
manager._load_existing_credential.assert_called_once_with(callback_context)
|
||||
manager._load_from_auth_response.assert_called_once_with(callback_context)
|
||||
manager._load_existing_credential.assert_called_once_with(tool_context)
|
||||
manager._load_from_auth_response.assert_called_once_with(tool_context)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -138,12 +138,12 @@ class TestCredentialManager:
|
||||
mock_credential = Mock(spec=AuthCredential)
|
||||
auth_config.exchanged_auth_credential = mock_credential
|
||||
|
||||
callback_context = Mock()
|
||||
tool_context = Mock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
manager._load_from_credential_service = AsyncMock(return_value=None)
|
||||
|
||||
result = await manager._load_existing_credential(callback_context)
|
||||
result = await manager._load_existing_credential(tool_context)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -155,23 +155,21 @@ class TestCredentialManager:
|
||||
|
||||
mock_credential = Mock(spec=AuthCredential)
|
||||
|
||||
callback_context = Mock()
|
||||
tool_context = Mock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
manager._load_from_credential_service = AsyncMock(
|
||||
return_value=mock_credential
|
||||
)
|
||||
|
||||
result = await manager._load_existing_credential(callback_context)
|
||||
result = await manager._load_existing_credential(tool_context)
|
||||
|
||||
manager._load_from_credential_service.assert_called_once_with(
|
||||
callback_context
|
||||
)
|
||||
manager._load_from_credential_service.assert_called_once_with(tool_context)
|
||||
assert result == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_from_credential_service_with_service(self):
|
||||
"""Test _load_from_credential_service from callback context when credential service is available."""
|
||||
"""Test _load_from_credential_service from tool context when credential service is available."""
|
||||
auth_config = Mock(spec=AuthConfig)
|
||||
|
||||
mock_credential = Mock(spec=AuthCredential)
|
||||
@@ -183,14 +181,14 @@ class TestCredentialManager:
|
||||
invocation_context = Mock()
|
||||
invocation_context.credential_service = credential_service
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
callback_context.load_credential = AsyncMock(return_value=mock_credential)
|
||||
tool_context = Mock()
|
||||
tool_context._invocation_context = invocation_context
|
||||
tool_context.load_credential = AsyncMock(return_value=mock_credential)
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
result = await manager._load_from_credential_service(callback_context)
|
||||
result = await manager._load_from_credential_service(tool_context)
|
||||
|
||||
callback_context.load_credential.assert_called_once_with(auth_config)
|
||||
tool_context.load_credential.assert_called_once_with(auth_config)
|
||||
assert result == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -202,11 +200,11 @@ class TestCredentialManager:
|
||||
invocation_context = Mock()
|
||||
invocation_context.credential_service = None
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
tool_context = Mock()
|
||||
tool_context._invocation_context = invocation_context
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
result = await manager._load_from_credential_service(callback_context)
|
||||
result = await manager._load_from_credential_service(tool_context)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -223,14 +221,14 @@ class TestCredentialManager:
|
||||
invocation_context = Mock()
|
||||
invocation_context.credential_service = credential_service
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
callback_context.save_credential = AsyncMock()
|
||||
tool_context = Mock()
|
||||
tool_context._invocation_context = invocation_context
|
||||
tool_context.save_credential = AsyncMock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
await manager._save_credential(callback_context, mock_credential)
|
||||
await manager._save_credential(tool_context, mock_credential)
|
||||
|
||||
callback_context.save_credential.assert_called_once_with(auth_config)
|
||||
tool_context.save_credential.assert_called_once_with(auth_config)
|
||||
assert auth_config.exchanged_auth_credential == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -244,11 +242,11 @@ class TestCredentialManager:
|
||||
invocation_context = Mock()
|
||||
invocation_context.credential_service = None
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
tool_context = Mock()
|
||||
tool_context._invocation_context = invocation_context
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
await manager._save_credential(callback_context, mock_credential)
|
||||
await manager._save_credential(tool_context, mock_credential)
|
||||
|
||||
# Should not raise an error, and credential should be set in auth_config
|
||||
# even when there's no credential service (config is updated regardless)
|
||||
|
||||
Reference in New Issue
Block a user