You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: cache canonical tools to avoid multiple calls when streaming
Merge https://github.com/google/adk-python/pull/3299 Fixes https://github.com/google/adk-python/issues/3237 Co-authored-by: Xuan Yang <xygoogle@google.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3299 from hcadioli:fix/cache-tools de02bd3e4533c3741edf05788a5e8b2d3d38bae4 PiperOrigin-RevId: 829499299
This commit is contained in:
committed by
Copybara-Service
parent
9761fc6bbb
commit
8f3c3bfda5
@@ -32,6 +32,7 @@ from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..plugins.plugin_manager import PluginManager
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.session import Session
|
||||
from ..tools.base_tool import BaseTool
|
||||
from .active_streaming_tool import ActiveStreamingTool
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent import BaseAgentState
|
||||
@@ -202,6 +203,9 @@ class InvocationContext(BaseModel):
|
||||
plugin_manager: PluginManager = Field(default_factory=PluginManager)
|
||||
"""The manager for keeping track of plugins in this invocation."""
|
||||
|
||||
canonical_tools_cache: Optional[list[BaseTool]] = None
|
||||
"""The cache of canonical tools for this invocation."""
|
||||
|
||||
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
|
||||
default_factory=_InvocationCostManager
|
||||
)
|
||||
|
||||
@@ -855,7 +855,10 @@ class BaseLlmFlow(ABC):
|
||||
response: Optional[LlmResponse] = None,
|
||||
) -> Optional[LlmResponse]:
|
||||
readonly_context = ReadonlyContext(invocation_context)
|
||||
tools = await agent.canonical_tools(readonly_context)
|
||||
if (tools := invocation_context.canonical_tools_cache) is None:
|
||||
tools = await agent.canonical_tools(readonly_context)
|
||||
invocation_context.canonical_tools_cache = tools
|
||||
|
||||
if not any(tool.name == 'google_search_agent' for tool in tools):
|
||||
return response
|
||||
ground_metadata = invocation_context.session.state.get(
|
||||
|
||||
@@ -413,3 +413,76 @@ async def test_handle_after_model_callback_grounding_with_plugin_override(
|
||||
|
||||
assert result == plugin_response
|
||||
plugin.after_model_callback.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_after_model_callback_caches_canonical_tools():
|
||||
"""Test that canonical_tools is only called once per invocation_context."""
|
||||
canonical_tools_call_count = 0
|
||||
|
||||
async def mock_canonical_tools(self, readonly_context=None):
|
||||
nonlocal canonical_tools_call_count
|
||||
canonical_tools_call_count += 1
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
|
||||
class MockGoogleSearchTool(BaseTool):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name='google_search_agent', description='Mock search')
|
||||
|
||||
async def call(self, **kwargs):
|
||||
return 'mock result'
|
||||
|
||||
return [MockGoogleSearchTool()]
|
||||
|
||||
agent = Agent(name='test_agent', tools=[google_search, dummy_tool])
|
||||
|
||||
with mock.patch.object(
|
||||
type(agent), 'canonical_tools', new=mock_canonical_tools
|
||||
):
|
||||
invocation_context = await testing_utils.create_invocation_context(
|
||||
agent=agent
|
||||
)
|
||||
|
||||
assert invocation_context.canonical_tools_cache is None
|
||||
|
||||
invocation_context.session.state['temp:_adk_grounding_metadata'] = {
|
||||
'foo': 'bar'
|
||||
}
|
||||
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(parts=[types.Part.from_text(text='response')])
|
||||
)
|
||||
event = Event(
|
||||
id=Event.new_id(),
|
||||
invocation_id=invocation_context.invocation_id,
|
||||
author=agent.name,
|
||||
)
|
||||
flow = BaseLlmFlowForTesting()
|
||||
|
||||
# Call _handle_after_model_callback multiple times with the same context
|
||||
result1 = await flow._handle_after_model_callback(
|
||||
invocation_context, llm_response, event
|
||||
)
|
||||
result2 = await flow._handle_after_model_callback(
|
||||
invocation_context, llm_response, event
|
||||
)
|
||||
result3 = await flow._handle_after_model_callback(
|
||||
invocation_context, llm_response, event
|
||||
)
|
||||
|
||||
assert canonical_tools_call_count == 1, (
|
||||
'canonical_tools should be called once, but was called '
|
||||
f'{canonical_tools_call_count} times'
|
||||
)
|
||||
|
||||
assert invocation_context.canonical_tools_cache is not None
|
||||
assert len(invocation_context.canonical_tools_cache) == 1
|
||||
assert (
|
||||
invocation_context.canonical_tools_cache[0].name
|
||||
== 'google_search_agent'
|
||||
)
|
||||
|
||||
assert result1.grounding_metadata == {'foo': 'bar'}
|
||||
assert result2.grounding_metadata == {'foo': 'bar'}
|
||||
assert result3.grounding_metadata == {'foo': 'bar'}
|
||||
|
||||
Reference in New Issue
Block a user