refactor: Add save_credential and load_credential in callback context for developer to access credential service

developer may want to save/load credentials themselves to/from credential service. see (https://github.com/google/adk-python/issues/1816)

PiperOrigin-RevId: 783487628
This commit is contained in:
Xiang (Sean) Zhou
2025-07-15 15:08:19 -07:00
committed by Copybara-Service
parent b1fa383e73
commit b977d12ea8
4 changed files with 198 additions and 20 deletions
+31
View File
@@ -24,6 +24,8 @@ from .readonly_context import ReadonlyContext
if TYPE_CHECKING:
from google.genai import types
from ..auth.auth_credential import AuthCredential
from ..auth.auth_tool import AuthConfig
from ..events.event_actions import EventActions
from ..sessions.state import State
from .invocation_context import InvocationContext
@@ -115,3 +117,32 @@ class CallbackContext(ReadonlyContext):
user_id=self._invocation_context.user_id,
session_id=self._invocation_context.session.id,
)
async def save_credential(self, auth_config: AuthConfig) -> None:
"""Saves a credential to the credential service.
Args:
auth_config: The authentication configuration containing the credential.
"""
if self._invocation_context.credential_service is None:
raise ValueError("Credential service is not initialized.")
await self._invocation_context.credential_service.save_credential(
auth_config, self
)
async def load_credential(
self, auth_config: AuthConfig
) -> Optional[AuthCredential]:
"""Loads a credential from the credential service.
Args:
auth_config: The authentication configuration for the credential.
Returns:
The loaded credential, or None if not found.
"""
if self._invocation_context.credential_service is None:
raise ValueError("Credential service is not initialized.")
return await self._invocation_context.credential_service.load_credential(
auth_config, self
)
+5 -8
View File
@@ -168,9 +168,7 @@ class CredentialManager:
if credential_service:
# Note: This should be made async in a future refactor
# For now, assuming synchronous operation
return await credential_service.load_credential(
self._auth_config, callback_context
)
return await callback_context.load_credential(self._auth_config)
return None
async def _load_from_auth_response(
@@ -255,10 +253,9 @@ class CredentialManager:
self, callback_context: CallbackContext, credential: AuthCredential
) -> None:
"""Save credential to credential service if available."""
# Update the exchanged credential in config
self._auth_config.exchanged_auth_credential = credential
credential_service = callback_context._invocation_context.credential_service
if credential_service:
# Update the exchanged credential in config
self._auth_config.exchanged_auth_credential = credential
await credential_service.save_credential(
self._auth_config, callback_context
)
await callback_context.save_credential(self._auth_config)
+155 -2
View File
@@ -16,9 +16,14 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import Mock
from google.adk.agents.callback_context import CallbackContext
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_tool import AuthConfig
from google.adk.tools.tool_context import ToolContext
from google.genai.types import Part
import pytest
@@ -32,6 +37,8 @@ def mock_invocation_context():
mock_context.session.id = "test-session-id"
mock_context.app_name = "test-app"
mock_context.user_id = "test-user"
mock_context.artifact_service = None
mock_context.credential_service = None
return mock_context
@@ -63,6 +70,21 @@ def callback_context_without_artifact_service(mock_invocation_context):
return CallbackContext(mock_invocation_context)
@pytest.fixture
def mock_auth_config():
"""Create a mock auth config for testing."""
mock_config = Mock(spec=AuthConfig)
return mock_config
@pytest.fixture
def mock_auth_credential():
"""Create a mock auth credential for testing."""
mock_credential = Mock(spec=AuthCredential)
mock_credential.auth_type = AuthCredentialTypes.OAUTH2
return mock_credential
class TestCallbackContextListArtifacts:
"""Test the list_artifacts method in CallbackContext."""
@@ -119,8 +141,8 @@ class TestCallbackContextListArtifacts:
await callback_context_with_artifact_service.list_artifacts()
class TestToolContextListArtifacts:
"""Test that list_artifacts is available in ToolContext through inheritance."""
class TestCallbackContext:
"""Test suite for CallbackContext."""
@pytest.mark.asyncio
async def test_tool_context_inherits_list_artifacts(
@@ -167,3 +189,134 @@ class TestToolContextListArtifacts:
):
"""Test that ToolContext and CallbackContext share the same list_artifacts method."""
assert ToolContext.list_artifacts is CallbackContext.list_artifacts
def test_initialization(self, mock_invocation_context):
"""Test CallbackContext initialization."""
context = CallbackContext(mock_invocation_context)
assert context._invocation_context == mock_invocation_context
assert context._event_actions is not None
assert context._state is not None
@pytest.mark.asyncio
async def test_save_credential_with_service(
self, mock_invocation_context, mock_auth_config
):
"""Test save_credential when credential service is available."""
# Mock credential service
credential_service = AsyncMock()
mock_invocation_context.credential_service = credential_service
context = CallbackContext(mock_invocation_context)
await context.save_credential(mock_auth_config)
credential_service.save_credential.assert_called_once_with(
mock_auth_config, context
)
@pytest.mark.asyncio
async def test_save_credential_no_service(
self, mock_invocation_context, mock_auth_config
):
"""Test save_credential when credential service is not available."""
mock_invocation_context.credential_service = None
context = CallbackContext(mock_invocation_context)
with pytest.raises(
ValueError, match="Credential service is not initialized"
):
await context.save_credential(mock_auth_config)
@pytest.mark.asyncio
async def test_load_credential_with_service(
self, mock_invocation_context, mock_auth_config, mock_auth_credential
):
"""Test load_credential when credential service is available."""
# Mock credential service
credential_service = AsyncMock()
credential_service.load_credential.return_value = mock_auth_credential
mock_invocation_context.credential_service = credential_service
context = CallbackContext(mock_invocation_context)
result = await context.load_credential(mock_auth_config)
credential_service.load_credential.assert_called_once_with(
mock_auth_config, context
)
assert result == mock_auth_credential
@pytest.mark.asyncio
async def test_load_credential_no_service(
self, mock_invocation_context, mock_auth_config
):
"""Test load_credential when credential service is not available."""
mock_invocation_context.credential_service = None
context = CallbackContext(mock_invocation_context)
with pytest.raises(
ValueError, match="Credential service is not initialized"
):
await context.load_credential(mock_auth_config)
@pytest.mark.asyncio
async def test_load_credential_returns_none(
self, mock_invocation_context, mock_auth_config
):
"""Test load_credential returns None when credential not found."""
# Mock credential service
credential_service = AsyncMock()
credential_service.load_credential.return_value = None
mock_invocation_context.credential_service = credential_service
context = CallbackContext(mock_invocation_context)
result = await context.load_credential(mock_auth_config)
credential_service.load_credential.assert_called_once_with(
mock_auth_config, context
)
assert result is None
@pytest.mark.asyncio
async def test_save_artifact_integration(self, mock_invocation_context):
"""Test save_artifact to ensure credential methods follow same pattern."""
# Mock artifact service
artifact_service = AsyncMock()
artifact_service.save_artifact.return_value = 1
mock_invocation_context.artifact_service = artifact_service
context = CallbackContext(mock_invocation_context)
test_artifact = Part.from_text(text="test content")
version = await context.save_artifact("test_file.txt", test_artifact)
artifact_service.save_artifact.assert_called_once_with(
app_name="test-app",
user_id="test-user",
session_id="test-session-id",
filename="test_file.txt",
artifact=test_artifact,
)
assert version == 1
@pytest.mark.asyncio
async def test_load_artifact_integration(self, mock_invocation_context):
"""Test load_artifact to ensure credential methods follow same pattern."""
# Mock artifact service
artifact_service = AsyncMock()
test_artifact = Part.from_text(text="test content")
artifact_service.load_artifact.return_value = test_artifact
mock_invocation_context.artifact_service = artifact_service
context = CallbackContext(mock_invocation_context)
result = await context.load_artifact("test_file.txt")
artifact_service.load_artifact.assert_called_once_with(
app_name="test-app",
user_id="test-user",
session_id="test-session-id",
filename="test_file.txt",
version=None,
)
assert result == test_artifact
@@ -167,7 +167,6 @@ class TestCredentialManager:
# Mock credential service
credential_service = Mock()
credential_service.load_credential = AsyncMock(return_value=mock_credential)
# Mock invocation context
invocation_context = Mock()
@@ -175,13 +174,12 @@ class TestCredentialManager:
callback_context = Mock()
callback_context._invocation_context = invocation_context
callback_context.load_credential = AsyncMock(return_value=mock_credential)
manager = CredentialManager(auth_config)
result = await manager._load_from_credential_service(callback_context)
credential_service.load_credential.assert_called_once_with(
auth_config, callback_context
)
callback_context.load_credential.assert_called_once_with(auth_config)
assert result == mock_credential
@pytest.mark.asyncio
@@ -216,13 +214,12 @@ class TestCredentialManager:
callback_context = Mock()
callback_context._invocation_context = invocation_context
callback_context.save_credential = AsyncMock()
manager = CredentialManager(auth_config)
await manager._save_credential(callback_context, mock_credential)
credential_service.save_credential.assert_called_once_with(
auth_config, callback_context
)
callback_context.save_credential.assert_called_once_with(auth_config)
assert auth_config.exchanged_auth_credential == mock_credential
@pytest.mark.asyncio
@@ -242,9 +239,9 @@ class TestCredentialManager:
manager = CredentialManager(auth_config)
await manager._save_credential(callback_context, mock_credential)
# Should not raise an error, and credential should not be set in auth_config
# when there's no credential service (according to implementation)
assert auth_config.exchanged_auth_credential is None
# Should not raise an error, and credential should be set in auth_config
# even when there's no credential service (config is updated regardless)
assert auth_config.exchanged_auth_credential == mock_credential
@pytest.mark.asyncio
async def test_refresh_credential_oauth2(self):