From fe82f3cde854e49be13d90b4c02d786d82f8a202 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Fri, 30 Jan 2026 21:59:15 -0800 Subject: [PATCH] fix!: Make credential manager to accept tool_context instead of callback_context This seems a breaking change, but actually credential manager is used internally only and also it won't work if some one call it using callback context Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 863534476 --- src/google/adk/auth/credential_manager.py | 38 +++++----- .../unittests/auth/test_credential_manager.py | 72 +++++++++---------- 2 files changed, 54 insertions(+), 56 deletions(-) diff --git a/src/google/adk/auth/credential_manager.py b/src/google/adk/auth/credential_manager.py index 0b6efd81..bf0ed1e2 100644 --- a/src/google/adk/auth/credential_manager.py +++ b/src/google/adk/auth/credential_manager.py @@ -19,8 +19,8 @@ from typing import Optional from fastapi.openapi.models import OAuth2 -from ..agents.callback_context import CallbackContext from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger +from ..tools.tool_context import ToolContext from ..utils.feature_decorator import experimental from .auth_credential import AuthCredential from .auth_credential import AuthCredentialTypes @@ -72,7 +72,7 @@ class CredentialManager: ) # Load and prepare credential - credential = await manager.load_auth_credential(callback_context) + credential = await manager.load_auth_credential(tool_context) ``` """ @@ -124,11 +124,11 @@ class CredentialManager: """ self._exchanger_registry.register(credential_type, exchanger_instance) - async def request_credential(self, callback_context: CallbackContext) -> None: - callback_context.request_credential(self._auth_config) + async def request_credential(self, tool_context: ToolContext) -> None: + tool_context.request_credential(self._auth_config) async def get_auth_credential( - self, callback_context: CallbackContext + self, tool_context: ToolContext ) -> Optional[AuthCredential]: """Load and prepare authentication credential through a structured workflow.""" @@ -140,14 +140,14 @@ class CredentialManager: return self._auth_config.raw_auth_credential # Step 3: Try to load existing processed credential - credential = await self._load_existing_credential(callback_context) + credential = await self._load_existing_credential(tool_context) # Step 4: If no existing credential, load from auth response # TODO instead of load from auth response, we can store auth response in # credential service. was_from_auth_response = False if not credential: - credential = await self._load_from_auth_response(callback_context) + credential = await self._load_from_auth_response(tool_context) was_from_auth_response = True # Step 5: If still no credential available, check if client credentials @@ -169,38 +169,38 @@ class CredentialManager: # Step 8: Save credential if it was modified if was_from_auth_response or was_exchanged or was_refreshed: - await self._save_credential(callback_context, credential) + await self._save_credential(tool_context, credential) return credential async def _load_existing_credential( - self, callback_context: CallbackContext + self, tool_context: ToolContext ) -> Optional[AuthCredential]: """Load existing credential from credential service.""" # Try loading from credential service first - credential = await self._load_from_credential_service(callback_context) + credential = await self._load_from_credential_service(tool_context) if credential: return credential return None async def _load_from_credential_service( - self, callback_context: CallbackContext + self, tool_context: ToolContext ) -> Optional[AuthCredential]: """Load credential from credential service if available.""" - credential_service = callback_context._invocation_context.credential_service + credential_service = tool_context._invocation_context.credential_service if credential_service: # Note: This should be made async in a future refactor # For now, assuming synchronous operation - return await callback_context.load_credential(self._auth_config) + return await tool_context.load_credential(self._auth_config) return None async def _load_from_auth_response( - self, callback_context: CallbackContext + self, tool_context: ToolContext ) -> Optional[AuthCredential]: - """Load credential from auth response in callback context.""" - return callback_context.get_auth_response(self._auth_config) + """Load credential from auth response in tool context.""" + return tool_context.get_auth_response(self._auth_config) async def _exchange_credential( self, credential: AuthCredential @@ -290,15 +290,15 @@ class CredentialManager: # Additional validation can be added here async def _save_credential( - self, callback_context: CallbackContext, credential: AuthCredential + self, tool_context: ToolContext, credential: AuthCredential ) -> None: """Save credential to credential service if available.""" # Update the exchanged credential in config self._auth_config.exchanged_auth_credential = credential - credential_service = callback_context._invocation_context.credential_service + credential_service = tool_context._invocation_context.credential_service if credential_service: - await callback_context.save_credential(self._auth_config) + await tool_context.save_credential(self._auth_config) async def _populate_auth_scheme(self) -> bool: """Auto-discover server metadata and populate missing auth scheme info. diff --git a/tests/unittests/auth/test_credential_manager.py b/tests/unittests/auth/test_credential_manager.py index 856780fc..95e01046 100644 --- a/tests/unittests/auth/test_credential_manager.py +++ b/tests/unittests/auth/test_credential_manager.py @@ -49,13 +49,13 @@ class TestCredentialManager: async def test_request_credential(self): """Test request_credential method.""" auth_config = Mock(spec=AuthConfig) - callback_context = Mock() - callback_context.request_credential = Mock() + tool_context = Mock() + tool_context.request_credential = Mock() manager = CredentialManager(auth_config) - await manager.request_credential(callback_context) + await manager.request_credential(tool_context) - callback_context.request_credential.assert_called_once_with(auth_config) + tool_context.request_credential.assert_called_once_with(auth_config) @pytest.mark.asyncio async def test_load_auth_credentials_success(self): @@ -69,7 +69,7 @@ class TestCredentialManager: mock_credential = Mock(spec=AuthCredential) mock_credential.auth_type = AuthCredentialTypes.API_KEY - callback_context = Mock() + tool_context = Mock() manager = CredentialManager(auth_config) @@ -86,17 +86,17 @@ class TestCredentialManager: ) manager._save_credential = AsyncMock() - result = await manager.get_auth_credential(callback_context) + result = await manager.get_auth_credential(tool_context) # Verify all methods were called manager._validate_credential.assert_called_once() manager._is_credential_ready.assert_called_once() - manager._load_existing_credential.assert_called_once_with(callback_context) - manager._load_from_auth_response.assert_called_once_with(callback_context) + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) manager._exchange_credential.assert_called_once_with(mock_credential) manager._refresh_credential.assert_called_once_with(mock_credential) manager._save_credential.assert_called_once_with( - callback_context, mock_credential + tool_context, mock_credential ) assert result == mock_credential @@ -111,7 +111,7 @@ class TestCredentialManager: auth_config.auth_scheme = Mock() auth_config.auth_scheme.flows = None - callback_context = Mock() + tool_context = Mock() manager = CredentialManager(auth_config) @@ -121,13 +121,13 @@ class TestCredentialManager: manager._load_existing_credential = AsyncMock(return_value=None) manager._load_from_auth_response = AsyncMock(return_value=None) - result = await manager.get_auth_credential(callback_context) + result = await manager.get_auth_credential(tool_context) # Verify methods were called but no credential returned manager._validate_credential.assert_called_once() manager._is_credential_ready.assert_called_once() - manager._load_existing_credential.assert_called_once_with(callback_context) - manager._load_from_auth_response.assert_called_once_with(callback_context) + manager._load_existing_credential.assert_called_once_with(tool_context) + manager._load_from_auth_response.assert_called_once_with(tool_context) assert result is None @@ -138,12 +138,12 @@ class TestCredentialManager: mock_credential = Mock(spec=AuthCredential) auth_config.exchanged_auth_credential = mock_credential - callback_context = Mock() + tool_context = Mock() manager = CredentialManager(auth_config) manager._load_from_credential_service = AsyncMock(return_value=None) - result = await manager._load_existing_credential(callback_context) + result = await manager._load_existing_credential(tool_context) assert result is None @@ -155,23 +155,21 @@ class TestCredentialManager: mock_credential = Mock(spec=AuthCredential) - callback_context = Mock() + tool_context = Mock() manager = CredentialManager(auth_config) manager._load_from_credential_service = AsyncMock( return_value=mock_credential ) - result = await manager._load_existing_credential(callback_context) + result = await manager._load_existing_credential(tool_context) - manager._load_from_credential_service.assert_called_once_with( - callback_context - ) + manager._load_from_credential_service.assert_called_once_with(tool_context) assert result == mock_credential @pytest.mark.asyncio async def test_load_from_credential_service_with_service(self): - """Test _load_from_credential_service from callback context when credential service is available.""" + """Test _load_from_credential_service from tool context when credential service is available.""" auth_config = Mock(spec=AuthConfig) mock_credential = Mock(spec=AuthCredential) @@ -183,14 +181,14 @@ class TestCredentialManager: invocation_context = Mock() invocation_context.credential_service = credential_service - callback_context = Mock() - callback_context._invocation_context = invocation_context - callback_context.load_credential = AsyncMock(return_value=mock_credential) + tool_context = Mock() + tool_context._invocation_context = invocation_context + tool_context.load_credential = AsyncMock(return_value=mock_credential) manager = CredentialManager(auth_config) - result = await manager._load_from_credential_service(callback_context) + result = await manager._load_from_credential_service(tool_context) - callback_context.load_credential.assert_called_once_with(auth_config) + tool_context.load_credential.assert_called_once_with(auth_config) assert result == mock_credential @pytest.mark.asyncio @@ -202,11 +200,11 @@ class TestCredentialManager: invocation_context = Mock() invocation_context.credential_service = None - callback_context = Mock() - callback_context._invocation_context = invocation_context + tool_context = Mock() + tool_context._invocation_context = invocation_context manager = CredentialManager(auth_config) - result = await manager._load_from_credential_service(callback_context) + result = await manager._load_from_credential_service(tool_context) assert result is None @@ -223,14 +221,14 @@ class TestCredentialManager: invocation_context = Mock() invocation_context.credential_service = credential_service - callback_context = Mock() - callback_context._invocation_context = invocation_context - callback_context.save_credential = AsyncMock() + tool_context = Mock() + tool_context._invocation_context = invocation_context + tool_context.save_credential = AsyncMock() manager = CredentialManager(auth_config) - await manager._save_credential(callback_context, mock_credential) + await manager._save_credential(tool_context, mock_credential) - callback_context.save_credential.assert_called_once_with(auth_config) + tool_context.save_credential.assert_called_once_with(auth_config) assert auth_config.exchanged_auth_credential == mock_credential @pytest.mark.asyncio @@ -244,11 +242,11 @@ class TestCredentialManager: invocation_context = Mock() invocation_context.credential_service = None - callback_context = Mock() - callback_context._invocation_context = invocation_context + tool_context = Mock() + tool_context._invocation_context = invocation_context manager = CredentialManager(auth_config) - await manager._save_credential(callback_context, mock_credential) + await manager._save_credential(tool_context, mock_credential) # Should not raise an error, and credential should be set in auth_config # even when there's no credential service (config is updated regardless)