From b977d12ea84eeac17e997ff3ac851b96dba85f3d Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 15 Jul 2025 15:08:19 -0700 Subject: [PATCH] refactor: Add save_credential and load_credential in callback context for developer to access credential service developer may want to save/load credentials themselves to/from credential service. see (https://github.com/google/adk-python/issues/1816) PiperOrigin-RevId: 783487628 --- src/google/adk/agents/callback_context.py | 31 ++++ src/google/adk/auth/credential_manager.py | 13 +- .../unittests/agents/test_callback_context.py | 157 +++++++++++++++++- .../unittests/auth/test_credential_manager.py | 17 +- 4 files changed, 198 insertions(+), 20 deletions(-) diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py index 522c9ef9..f42d344a 100644 --- a/src/google/adk/agents/callback_context.py +++ b/src/google/adk/agents/callback_context.py @@ -24,6 +24,8 @@ from .readonly_context import ReadonlyContext if TYPE_CHECKING: from google.genai import types + from ..auth.auth_credential import AuthCredential + from ..auth.auth_tool import AuthConfig from ..events.event_actions import EventActions from ..sessions.state import State from .invocation_context import InvocationContext @@ -115,3 +117,32 @@ class CallbackContext(ReadonlyContext): user_id=self._invocation_context.user_id, session_id=self._invocation_context.session.id, ) + + async def save_credential(self, auth_config: AuthConfig) -> None: + """Saves a credential to the credential service. + + Args: + auth_config: The authentication configuration containing the credential. + """ + if self._invocation_context.credential_service is None: + raise ValueError("Credential service is not initialized.") + await self._invocation_context.credential_service.save_credential( + auth_config, self + ) + + async def load_credential( + self, auth_config: AuthConfig + ) -> Optional[AuthCredential]: + """Loads a credential from the credential service. + + Args: + auth_config: The authentication configuration for the credential. + + Returns: + The loaded credential, or None if not found. + """ + if self._invocation_context.credential_service is None: + raise ValueError("Credential service is not initialized.") + return await self._invocation_context.credential_service.load_credential( + auth_config, self + ) diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index 5cf20366..c5dae9f5 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -168,9 +168,7 @@ class CredentialManager: if credential_service: # Note: This should be made async in a future refactor # For now, assuming synchronous operation - return await credential_service.load_credential( - self._auth_config, callback_context - ) + return await callback_context.load_credential(self._auth_config) return None async def _load_from_auth_response( @@ -255,10 +253,9 @@ class CredentialManager: self, callback_context: CallbackContext, credential: AuthCredential ) -> None: """Save credential to credential service if available.""" + # Update the exchanged credential in config + self._auth_config.exchanged_auth_credential = credential + credential_service = callback_context._invocation_context.credential_service if credential_service: - # Update the exchanged credential in config - self._auth_config.exchanged_auth_credential = credential - await credential_service.save_credential( - self._auth_config, callback_context - ) + await callback_context.save_credential(self._auth_config) diff --git a/tests/unittests/agents/test_callback_context.py b/tests/unittests/agents/test_callback_context.py index 4acb6d2d..fb8b2ae7 100644 --- a/tests/unittests/agents/test_callback_context.py +++ b/tests/unittests/agents/test_callback_context.py @@ -16,9 +16,14 @@ from unittest.mock import AsyncMock from unittest.mock import MagicMock +from unittest.mock import Mock from google.adk.agents.callback_context import CallbackContext +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_tool import AuthConfig from google.adk.tools.tool_context import ToolContext +from google.genai.types import Part import pytest @@ -32,6 +37,8 @@ def mock_invocation_context(): mock_context.session.id = "test-session-id" mock_context.app_name = "test-app" mock_context.user_id = "test-user" + mock_context.artifact_service = None + mock_context.credential_service = None return mock_context @@ -63,6 +70,21 @@ def callback_context_without_artifact_service(mock_invocation_context): return CallbackContext(mock_invocation_context) +@pytest.fixture +def mock_auth_config(): + """Create a mock auth config for testing.""" + mock_config = Mock(spec=AuthConfig) + return mock_config + + +@pytest.fixture +def mock_auth_credential(): + """Create a mock auth credential for testing.""" + mock_credential = Mock(spec=AuthCredential) + mock_credential.auth_type = AuthCredentialTypes.OAUTH2 + return mock_credential + + class TestCallbackContextListArtifacts: """Test the list_artifacts method in CallbackContext.""" @@ -119,8 +141,8 @@ class TestCallbackContextListArtifacts: await callback_context_with_artifact_service.list_artifacts() -class TestToolContextListArtifacts: - """Test that list_artifacts is available in ToolContext through inheritance.""" +class TestCallbackContext: + """Test suite for CallbackContext.""" @pytest.mark.asyncio async def test_tool_context_inherits_list_artifacts( @@ -167,3 +189,134 @@ class TestToolContextListArtifacts: ): """Test that ToolContext and CallbackContext share the same list_artifacts method.""" assert ToolContext.list_artifacts is CallbackContext.list_artifacts + + def test_initialization(self, mock_invocation_context): + """Test CallbackContext initialization.""" + context = CallbackContext(mock_invocation_context) + assert context._invocation_context == mock_invocation_context + assert context._event_actions is not None + assert context._state is not None + + @pytest.mark.asyncio + async def test_save_credential_with_service( + self, mock_invocation_context, mock_auth_config + ): + """Test save_credential when credential service is available.""" + # Mock credential service + credential_service = AsyncMock() + mock_invocation_context.credential_service = credential_service + + context = CallbackContext(mock_invocation_context) + await context.save_credential(mock_auth_config) + + credential_service.save_credential.assert_called_once_with( + mock_auth_config, context + ) + + @pytest.mark.asyncio + async def test_save_credential_no_service( + self, mock_invocation_context, mock_auth_config + ): + """Test save_credential when credential service is not available.""" + mock_invocation_context.credential_service = None + + context = CallbackContext(mock_invocation_context) + + with pytest.raises( + ValueError, match="Credential service is not initialized" + ): + await context.save_credential(mock_auth_config) + + @pytest.mark.asyncio + async def test_load_credential_with_service( + self, mock_invocation_context, mock_auth_config, mock_auth_credential + ): + """Test load_credential when credential service is available.""" + # Mock credential service + credential_service = AsyncMock() + credential_service.load_credential.return_value = mock_auth_credential + mock_invocation_context.credential_service = credential_service + + context = CallbackContext(mock_invocation_context) + result = await context.load_credential(mock_auth_config) + + credential_service.load_credential.assert_called_once_with( + mock_auth_config, context + ) + assert result == mock_auth_credential + + @pytest.mark.asyncio + async def test_load_credential_no_service( + self, mock_invocation_context, mock_auth_config + ): + """Test load_credential when credential service is not available.""" + mock_invocation_context.credential_service = None + + context = CallbackContext(mock_invocation_context) + + with pytest.raises( + ValueError, match="Credential service is not initialized" + ): + await context.load_credential(mock_auth_config) + + @pytest.mark.asyncio + async def test_load_credential_returns_none( + self, mock_invocation_context, mock_auth_config + ): + """Test load_credential returns None when credential not found.""" + # Mock credential service + credential_service = AsyncMock() + credential_service.load_credential.return_value = None + mock_invocation_context.credential_service = credential_service + + context = CallbackContext(mock_invocation_context) + result = await context.load_credential(mock_auth_config) + + credential_service.load_credential.assert_called_once_with( + mock_auth_config, context + ) + assert result is None + + @pytest.mark.asyncio + async def test_save_artifact_integration(self, mock_invocation_context): + """Test save_artifact to ensure credential methods follow same pattern.""" + # Mock artifact service + artifact_service = AsyncMock() + artifact_service.save_artifact.return_value = 1 + mock_invocation_context.artifact_service = artifact_service + + context = CallbackContext(mock_invocation_context) + test_artifact = Part.from_text(text="test content") + + version = await context.save_artifact("test_file.txt", test_artifact) + + artifact_service.save_artifact.assert_called_once_with( + app_name="test-app", + user_id="test-user", + session_id="test-session-id", + filename="test_file.txt", + artifact=test_artifact, + ) + assert version == 1 + + @pytest.mark.asyncio + async def test_load_artifact_integration(self, mock_invocation_context): + """Test load_artifact to ensure credential methods follow same pattern.""" + # Mock artifact service + artifact_service = AsyncMock() + test_artifact = Part.from_text(text="test content") + artifact_service.load_artifact.return_value = test_artifact + mock_invocation_context.artifact_service = artifact_service + + context = CallbackContext(mock_invocation_context) + + result = await context.load_artifact("test_file.txt") + + artifact_service.load_artifact.assert_called_once_with( + app_name="test-app", + user_id="test-user", + session_id="test-session-id", + filename="test_file.txt", + version=None, + ) + assert result == test_artifact diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index 398ee938..fd978604 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -167,7 +167,6 @@ class TestCredentialManager: # Mock credential service credential_service = Mock() - credential_service.load_credential = AsyncMock(return_value=mock_credential) # Mock invocation context invocation_context = Mock() @@ -175,13 +174,12 @@ class TestCredentialManager: callback_context = Mock() callback_context._invocation_context = invocation_context + callback_context.load_credential = AsyncMock(return_value=mock_credential) manager = CredentialManager(auth_config) result = await manager._load_from_credential_service(callback_context) - credential_service.load_credential.assert_called_once_with( - auth_config, callback_context - ) + callback_context.load_credential.assert_called_once_with(auth_config) assert result == mock_credential @pytest.mark.asyncio @@ -216,13 +214,12 @@ class TestCredentialManager: callback_context = Mock() callback_context._invocation_context = invocation_context + callback_context.save_credential = AsyncMock() manager = CredentialManager(auth_config) await manager._save_credential(callback_context, mock_credential) - credential_service.save_credential.assert_called_once_with( - auth_config, callback_context - ) + callback_context.save_credential.assert_called_once_with(auth_config) assert auth_config.exchanged_auth_credential == mock_credential @pytest.mark.asyncio @@ -242,9 +239,9 @@ class TestCredentialManager: manager = CredentialManager(auth_config) await manager._save_credential(callback_context, mock_credential) - # Should not raise an error, and credential should not be set in auth_config - # when there's no credential service (according to implementation) - assert auth_config.exchanged_auth_credential is None + # Should not raise an error, and credential should be set in auth_config + # even when there's no credential service (config is updated regardless) + assert auth_config.exchanged_auth_credential == mock_credential @pytest.mark.asyncio async def test_refresh_credential_oauth2(self):