fix!: Make credential manager to accept tool_context instead of callback_context

This seems a breaking change, but actually credential manager is used internally only and also it won't work if some one call it using callback context

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 863534476
This commit is contained in:
Xiang (Sean) Zhou
2026-01-30 21:59:15 -08:00
committed by Copybara-Service
parent 798d0053c8
commit fe82f3cde8
2 changed files with 54 additions and 56 deletions
+19 -19
View File
@@ -19,8 +19,8 @@ from typing import Optional
from fastapi.openapi.models import OAuth2
from ..agents.callback_context import CallbackContext
from ..tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger
from ..tools.tool_context import ToolContext
from ..utils.feature_decorator import experimental
from .auth_credential import AuthCredential
from .auth_credential import AuthCredentialTypes
@@ -72,7 +72,7 @@ class CredentialManager:
)
# Load and prepare credential
credential = await manager.load_auth_credential(callback_context)
credential = await manager.load_auth_credential(tool_context)
```
"""
@@ -124,11 +124,11 @@ class CredentialManager:
"""
self._exchanger_registry.register(credential_type, exchanger_instance)
async def request_credential(self, callback_context: CallbackContext) -> None:
callback_context.request_credential(self._auth_config)
async def request_credential(self, tool_context: ToolContext) -> None:
tool_context.request_credential(self._auth_config)
async def get_auth_credential(
self, callback_context: CallbackContext
self, tool_context: ToolContext
) -> Optional[AuthCredential]:
"""Load and prepare authentication credential through a structured workflow."""
@@ -140,14 +140,14 @@ class CredentialManager:
return self._auth_config.raw_auth_credential
# Step 3: Try to load existing processed credential
credential = await self._load_existing_credential(callback_context)
credential = await self._load_existing_credential(tool_context)
# Step 4: If no existing credential, load from auth response
# TODO instead of load from auth response, we can store auth response in
# credential service.
was_from_auth_response = False
if not credential:
credential = await self._load_from_auth_response(callback_context)
credential = await self._load_from_auth_response(tool_context)
was_from_auth_response = True
# Step 5: If still no credential available, check if client credentials
@@ -169,38 +169,38 @@ class CredentialManager:
# Step 8: Save credential if it was modified
if was_from_auth_response or was_exchanged or was_refreshed:
await self._save_credential(callback_context, credential)
await self._save_credential(tool_context, credential)
return credential
async def _load_existing_credential(
self, callback_context: CallbackContext
self, tool_context: ToolContext
) -> Optional[AuthCredential]:
"""Load existing credential from credential service."""
# Try loading from credential service first
credential = await self._load_from_credential_service(callback_context)
credential = await self._load_from_credential_service(tool_context)
if credential:
return credential
return None
async def _load_from_credential_service(
self, callback_context: CallbackContext
self, tool_context: ToolContext
) -> Optional[AuthCredential]:
"""Load credential from credential service if available."""
credential_service = callback_context._invocation_context.credential_service
credential_service = tool_context._invocation_context.credential_service
if credential_service:
# Note: This should be made async in a future refactor
# For now, assuming synchronous operation
return await callback_context.load_credential(self._auth_config)
return await tool_context.load_credential(self._auth_config)
return None
async def _load_from_auth_response(
self, callback_context: CallbackContext
self, tool_context: ToolContext
) -> Optional[AuthCredential]:
"""Load credential from auth response in callback context."""
return callback_context.get_auth_response(self._auth_config)
"""Load credential from auth response in tool context."""
return tool_context.get_auth_response(self._auth_config)
async def _exchange_credential(
self, credential: AuthCredential
@@ -290,15 +290,15 @@ class CredentialManager:
# Additional validation can be added here
async def _save_credential(
self, callback_context: CallbackContext, credential: AuthCredential
self, tool_context: ToolContext, credential: AuthCredential
) -> None:
"""Save credential to credential service if available."""
# Update the exchanged credential in config
self._auth_config.exchanged_auth_credential = credential
credential_service = callback_context._invocation_context.credential_service
credential_service = tool_context._invocation_context.credential_service
if credential_service:
await callback_context.save_credential(self._auth_config)
await tool_context.save_credential(self._auth_config)
async def _populate_auth_scheme(self) -> bool:
"""Auto-discover server metadata and populate missing auth scheme info.
+35 -37
View File
@@ -49,13 +49,13 @@ class TestCredentialManager:
async def test_request_credential(self):
"""Test request_credential method."""
auth_config = Mock(spec=AuthConfig)
callback_context = Mock()
callback_context.request_credential = Mock()
tool_context = Mock()
tool_context.request_credential = Mock()
manager = CredentialManager(auth_config)
await manager.request_credential(callback_context)
await manager.request_credential(tool_context)
callback_context.request_credential.assert_called_once_with(auth_config)
tool_context.request_credential.assert_called_once_with(auth_config)
@pytest.mark.asyncio
async def test_load_auth_credentials_success(self):
@@ -69,7 +69,7 @@ class TestCredentialManager:
mock_credential = Mock(spec=AuthCredential)
mock_credential.auth_type = AuthCredentialTypes.API_KEY
callback_context = Mock()
tool_context = Mock()
manager = CredentialManager(auth_config)
@@ -86,17 +86,17 @@ class TestCredentialManager:
)
manager._save_credential = AsyncMock()
result = await manager.get_auth_credential(callback_context)
result = await manager.get_auth_credential(tool_context)
# Verify all methods were called
manager._validate_credential.assert_called_once()
manager._is_credential_ready.assert_called_once()
manager._load_existing_credential.assert_called_once_with(callback_context)
manager._load_from_auth_response.assert_called_once_with(callback_context)
manager._load_existing_credential.assert_called_once_with(tool_context)
manager._load_from_auth_response.assert_called_once_with(tool_context)
manager._exchange_credential.assert_called_once_with(mock_credential)
manager._refresh_credential.assert_called_once_with(mock_credential)
manager._save_credential.assert_called_once_with(
callback_context, mock_credential
tool_context, mock_credential
)
assert result == mock_credential
@@ -111,7 +111,7 @@ class TestCredentialManager:
auth_config.auth_scheme = Mock()
auth_config.auth_scheme.flows = None
callback_context = Mock()
tool_context = Mock()
manager = CredentialManager(auth_config)
@@ -121,13 +121,13 @@ class TestCredentialManager:
manager._load_existing_credential = AsyncMock(return_value=None)
manager._load_from_auth_response = AsyncMock(return_value=None)
result = await manager.get_auth_credential(callback_context)
result = await manager.get_auth_credential(tool_context)
# Verify methods were called but no credential returned
manager._validate_credential.assert_called_once()
manager._is_credential_ready.assert_called_once()
manager._load_existing_credential.assert_called_once_with(callback_context)
manager._load_from_auth_response.assert_called_once_with(callback_context)
manager._load_existing_credential.assert_called_once_with(tool_context)
manager._load_from_auth_response.assert_called_once_with(tool_context)
assert result is None
@@ -138,12 +138,12 @@ class TestCredentialManager:
mock_credential = Mock(spec=AuthCredential)
auth_config.exchanged_auth_credential = mock_credential
callback_context = Mock()
tool_context = Mock()
manager = CredentialManager(auth_config)
manager._load_from_credential_service = AsyncMock(return_value=None)
result = await manager._load_existing_credential(callback_context)
result = await manager._load_existing_credential(tool_context)
assert result is None
@@ -155,23 +155,21 @@ class TestCredentialManager:
mock_credential = Mock(spec=AuthCredential)
callback_context = Mock()
tool_context = Mock()
manager = CredentialManager(auth_config)
manager._load_from_credential_service = AsyncMock(
return_value=mock_credential
)
result = await manager._load_existing_credential(callback_context)
result = await manager._load_existing_credential(tool_context)
manager._load_from_credential_service.assert_called_once_with(
callback_context
)
manager._load_from_credential_service.assert_called_once_with(tool_context)
assert result == mock_credential
@pytest.mark.asyncio
async def test_load_from_credential_service_with_service(self):
"""Test _load_from_credential_service from callback context when credential service is available."""
"""Test _load_from_credential_service from tool context when credential service is available."""
auth_config = Mock(spec=AuthConfig)
mock_credential = Mock(spec=AuthCredential)
@@ -183,14 +181,14 @@ class TestCredentialManager:
invocation_context = Mock()
invocation_context.credential_service = credential_service
callback_context = Mock()
callback_context._invocation_context = invocation_context
callback_context.load_credential = AsyncMock(return_value=mock_credential)
tool_context = Mock()
tool_context._invocation_context = invocation_context
tool_context.load_credential = AsyncMock(return_value=mock_credential)
manager = CredentialManager(auth_config)
result = await manager._load_from_credential_service(callback_context)
result = await manager._load_from_credential_service(tool_context)
callback_context.load_credential.assert_called_once_with(auth_config)
tool_context.load_credential.assert_called_once_with(auth_config)
assert result == mock_credential
@pytest.mark.asyncio
@@ -202,11 +200,11 @@ class TestCredentialManager:
invocation_context = Mock()
invocation_context.credential_service = None
callback_context = Mock()
callback_context._invocation_context = invocation_context
tool_context = Mock()
tool_context._invocation_context = invocation_context
manager = CredentialManager(auth_config)
result = await manager._load_from_credential_service(callback_context)
result = await manager._load_from_credential_service(tool_context)
assert result is None
@@ -223,14 +221,14 @@ class TestCredentialManager:
invocation_context = Mock()
invocation_context.credential_service = credential_service
callback_context = Mock()
callback_context._invocation_context = invocation_context
callback_context.save_credential = AsyncMock()
tool_context = Mock()
tool_context._invocation_context = invocation_context
tool_context.save_credential = AsyncMock()
manager = CredentialManager(auth_config)
await manager._save_credential(callback_context, mock_credential)
await manager._save_credential(tool_context, mock_credential)
callback_context.save_credential.assert_called_once_with(auth_config)
tool_context.save_credential.assert_called_once_with(auth_config)
assert auth_config.exchanged_auth_credential == mock_credential
@pytest.mark.asyncio
@@ -244,11 +242,11 @@ class TestCredentialManager:
invocation_context = Mock()
invocation_context.credential_service = None
callback_context = Mock()
callback_context._invocation_context = invocation_context
tool_context = Mock()
tool_context._invocation_context = invocation_context
manager = CredentialManager(auth_config)
await manager._save_credential(callback_context, mock_credential)
await manager._save_credential(tool_context, mock_credential)
# Should not raise an error, and credential should be set in auth_config
# even when there's no credential service (config is updated regardless)