You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
refactor: Add save_credential and load_credential in callback context for developer to access credential service
developer may want to save/load credentials themselves to/from credential service. see (https://github.com/google/adk-python/issues/1816) PiperOrigin-RevId: 783487628
This commit is contained in:
committed by
Copybara-Service
parent
b1fa383e73
commit
b977d12ea8
@@ -24,6 +24,8 @@ from .readonly_context import ReadonlyContext
|
||||
if TYPE_CHECKING:
|
||||
from google.genai import types
|
||||
|
||||
from ..auth.auth_credential import AuthCredential
|
||||
from ..auth.auth_tool import AuthConfig
|
||||
from ..events.event_actions import EventActions
|
||||
from ..sessions.state import State
|
||||
from .invocation_context import InvocationContext
|
||||
@@ -115,3 +117,32 @@ class CallbackContext(ReadonlyContext):
|
||||
user_id=self._invocation_context.user_id,
|
||||
session_id=self._invocation_context.session.id,
|
||||
)
|
||||
|
||||
async def save_credential(self, auth_config: AuthConfig) -> None:
|
||||
"""Saves a credential to the credential service.
|
||||
|
||||
Args:
|
||||
auth_config: The authentication configuration containing the credential.
|
||||
"""
|
||||
if self._invocation_context.credential_service is None:
|
||||
raise ValueError("Credential service is not initialized.")
|
||||
await self._invocation_context.credential_service.save_credential(
|
||||
auth_config, self
|
||||
)
|
||||
|
||||
async def load_credential(
|
||||
self, auth_config: AuthConfig
|
||||
) -> Optional[AuthCredential]:
|
||||
"""Loads a credential from the credential service.
|
||||
|
||||
Args:
|
||||
auth_config: The authentication configuration for the credential.
|
||||
|
||||
Returns:
|
||||
The loaded credential, or None if not found.
|
||||
"""
|
||||
if self._invocation_context.credential_service is None:
|
||||
raise ValueError("Credential service is not initialized.")
|
||||
return await self._invocation_context.credential_service.load_credential(
|
||||
auth_config, self
|
||||
)
|
||||
|
||||
@@ -168,9 +168,7 @@ class CredentialManager:
|
||||
if credential_service:
|
||||
# Note: This should be made async in a future refactor
|
||||
# For now, assuming synchronous operation
|
||||
return await credential_service.load_credential(
|
||||
self._auth_config, callback_context
|
||||
)
|
||||
return await callback_context.load_credential(self._auth_config)
|
||||
return None
|
||||
|
||||
async def _load_from_auth_response(
|
||||
@@ -255,10 +253,9 @@ class CredentialManager:
|
||||
self, callback_context: CallbackContext, credential: AuthCredential
|
||||
) -> None:
|
||||
"""Save credential to credential service if available."""
|
||||
# Update the exchanged credential in config
|
||||
self._auth_config.exchanged_auth_credential = credential
|
||||
|
||||
credential_service = callback_context._invocation_context.credential_service
|
||||
if credential_service:
|
||||
# Update the exchanged credential in config
|
||||
self._auth_config.exchanged_auth_credential = credential
|
||||
await credential_service.save_credential(
|
||||
self._auth_config, callback_context
|
||||
)
|
||||
await callback_context.save_credential(self._auth_config)
|
||||
|
||||
@@ -16,9 +16,14 @@
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -32,6 +37,8 @@ def mock_invocation_context():
|
||||
mock_context.session.id = "test-session-id"
|
||||
mock_context.app_name = "test-app"
|
||||
mock_context.user_id = "test-user"
|
||||
mock_context.artifact_service = None
|
||||
mock_context.credential_service = None
|
||||
return mock_context
|
||||
|
||||
|
||||
@@ -63,6 +70,21 @@ def callback_context_without_artifact_service(mock_invocation_context):
|
||||
return CallbackContext(mock_invocation_context)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_config():
|
||||
"""Create a mock auth config for testing."""
|
||||
mock_config = Mock(spec=AuthConfig)
|
||||
return mock_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_credential():
|
||||
"""Create a mock auth credential for testing."""
|
||||
mock_credential = Mock(spec=AuthCredential)
|
||||
mock_credential.auth_type = AuthCredentialTypes.OAUTH2
|
||||
return mock_credential
|
||||
|
||||
|
||||
class TestCallbackContextListArtifacts:
|
||||
"""Test the list_artifacts method in CallbackContext."""
|
||||
|
||||
@@ -119,8 +141,8 @@ class TestCallbackContextListArtifacts:
|
||||
await callback_context_with_artifact_service.list_artifacts()
|
||||
|
||||
|
||||
class TestToolContextListArtifacts:
|
||||
"""Test that list_artifacts is available in ToolContext through inheritance."""
|
||||
class TestCallbackContext:
|
||||
"""Test suite for CallbackContext."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_context_inherits_list_artifacts(
|
||||
@@ -167,3 +189,134 @@ class TestToolContextListArtifacts:
|
||||
):
|
||||
"""Test that ToolContext and CallbackContext share the same list_artifacts method."""
|
||||
assert ToolContext.list_artifacts is CallbackContext.list_artifacts
|
||||
|
||||
def test_initialization(self, mock_invocation_context):
|
||||
"""Test CallbackContext initialization."""
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
assert context._invocation_context == mock_invocation_context
|
||||
assert context._event_actions is not None
|
||||
assert context._state is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_credential_with_service(
|
||||
self, mock_invocation_context, mock_auth_config
|
||||
):
|
||||
"""Test save_credential when credential service is available."""
|
||||
# Mock credential service
|
||||
credential_service = AsyncMock()
|
||||
mock_invocation_context.credential_service = credential_service
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
await context.save_credential(mock_auth_config)
|
||||
|
||||
credential_service.save_credential.assert_called_once_with(
|
||||
mock_auth_config, context
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_credential_no_service(
|
||||
self, mock_invocation_context, mock_auth_config
|
||||
):
|
||||
"""Test save_credential when credential service is not available."""
|
||||
mock_invocation_context.credential_service = None
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Credential service is not initialized"
|
||||
):
|
||||
await context.save_credential(mock_auth_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_credential_with_service(
|
||||
self, mock_invocation_context, mock_auth_config, mock_auth_credential
|
||||
):
|
||||
"""Test load_credential when credential service is available."""
|
||||
# Mock credential service
|
||||
credential_service = AsyncMock()
|
||||
credential_service.load_credential.return_value = mock_auth_credential
|
||||
mock_invocation_context.credential_service = credential_service
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
result = await context.load_credential(mock_auth_config)
|
||||
|
||||
credential_service.load_credential.assert_called_once_with(
|
||||
mock_auth_config, context
|
||||
)
|
||||
assert result == mock_auth_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_credential_no_service(
|
||||
self, mock_invocation_context, mock_auth_config
|
||||
):
|
||||
"""Test load_credential when credential service is not available."""
|
||||
mock_invocation_context.credential_service = None
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Credential service is not initialized"
|
||||
):
|
||||
await context.load_credential(mock_auth_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_credential_returns_none(
|
||||
self, mock_invocation_context, mock_auth_config
|
||||
):
|
||||
"""Test load_credential returns None when credential not found."""
|
||||
# Mock credential service
|
||||
credential_service = AsyncMock()
|
||||
credential_service.load_credential.return_value = None
|
||||
mock_invocation_context.credential_service = credential_service
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
result = await context.load_credential(mock_auth_config)
|
||||
|
||||
credential_service.load_credential.assert_called_once_with(
|
||||
mock_auth_config, context
|
||||
)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_artifact_integration(self, mock_invocation_context):
|
||||
"""Test save_artifact to ensure credential methods follow same pattern."""
|
||||
# Mock artifact service
|
||||
artifact_service = AsyncMock()
|
||||
artifact_service.save_artifact.return_value = 1
|
||||
mock_invocation_context.artifact_service = artifact_service
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
test_artifact = Part.from_text(text="test content")
|
||||
|
||||
version = await context.save_artifact("test_file.txt", test_artifact)
|
||||
|
||||
artifact_service.save_artifact.assert_called_once_with(
|
||||
app_name="test-app",
|
||||
user_id="test-user",
|
||||
session_id="test-session-id",
|
||||
filename="test_file.txt",
|
||||
artifact=test_artifact,
|
||||
)
|
||||
assert version == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_artifact_integration(self, mock_invocation_context):
|
||||
"""Test load_artifact to ensure credential methods follow same pattern."""
|
||||
# Mock artifact service
|
||||
artifact_service = AsyncMock()
|
||||
test_artifact = Part.from_text(text="test content")
|
||||
artifact_service.load_artifact.return_value = test_artifact
|
||||
mock_invocation_context.artifact_service = artifact_service
|
||||
|
||||
context = CallbackContext(mock_invocation_context)
|
||||
|
||||
result = await context.load_artifact("test_file.txt")
|
||||
|
||||
artifact_service.load_artifact.assert_called_once_with(
|
||||
app_name="test-app",
|
||||
user_id="test-user",
|
||||
session_id="test-session-id",
|
||||
filename="test_file.txt",
|
||||
version=None,
|
||||
)
|
||||
assert result == test_artifact
|
||||
|
||||
@@ -167,7 +167,6 @@ class TestCredentialManager:
|
||||
|
||||
# Mock credential service
|
||||
credential_service = Mock()
|
||||
credential_service.load_credential = AsyncMock(return_value=mock_credential)
|
||||
|
||||
# Mock invocation context
|
||||
invocation_context = Mock()
|
||||
@@ -175,13 +174,12 @@ class TestCredentialManager:
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
callback_context.load_credential = AsyncMock(return_value=mock_credential)
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
result = await manager._load_from_credential_service(callback_context)
|
||||
|
||||
credential_service.load_credential.assert_called_once_with(
|
||||
auth_config, callback_context
|
||||
)
|
||||
callback_context.load_credential.assert_called_once_with(auth_config)
|
||||
assert result == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -216,13 +214,12 @@ class TestCredentialManager:
|
||||
|
||||
callback_context = Mock()
|
||||
callback_context._invocation_context = invocation_context
|
||||
callback_context.save_credential = AsyncMock()
|
||||
|
||||
manager = CredentialManager(auth_config)
|
||||
await manager._save_credential(callback_context, mock_credential)
|
||||
|
||||
credential_service.save_credential.assert_called_once_with(
|
||||
auth_config, callback_context
|
||||
)
|
||||
callback_context.save_credential.assert_called_once_with(auth_config)
|
||||
assert auth_config.exchanged_auth_credential == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -242,9 +239,9 @@ class TestCredentialManager:
|
||||
manager = CredentialManager(auth_config)
|
||||
await manager._save_credential(callback_context, mock_credential)
|
||||
|
||||
# Should not raise an error, and credential should not be set in auth_config
|
||||
# when there's no credential service (according to implementation)
|
||||
assert auth_config.exchanged_auth_credential is None
|
||||
# Should not raise an error, and credential should be set in auth_config
|
||||
# even when there's no credential service (config is updated regardless)
|
||||
assert auth_config.exchanged_auth_credential == mock_credential
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_credential_oauth2(self):
|
||||
|
||||
Reference in New Issue
Block a user