fix: Use the agent's model when creating Google search agent tool

PiperOrigin-RevId: 819980797
This commit is contained in:
Xuan Yang
2025-10-15 17:40:49 -07:00
committed by Copybara-Service
parent 86097afe49
commit 3734ceaa6c
3 changed files with 52 additions and 7 deletions
@@ -17,7 +17,7 @@ import random
from dotenv import load_dotenv
from google.adk import Agent
from google.adk.tools.google_search_tool import google_search
from google.adk.tools.google_search_tool import GoogleSearchTool
from google.adk.tools.tool_context import ToolContext
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
@@ -57,7 +57,9 @@ root_agent = Agent(
""",
tools=[
roll_die,
VertexAiSearchTool(data_store_id=VERTEXAI_DATASTORE_ID),
google_search,
VertexAiSearchTool(
data_store_id=VERTEXAI_DATASTORE_ID, bypass_multi_tools_limit=True
),
GoogleSearchTool(bypass_multi_tools_limit=True),
],
)
@@ -438,6 +438,10 @@ class BaseLlmFlow(ABC):
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
raise TypeError(
f'Expected agent to be an LlmAgent, but got {type(agent)}'
)
# Runs processors.
for processor in self.request_processors:
@@ -468,7 +472,7 @@ class BaseLlmFlow(ABC):
tools = await _convert_tool_union_to_tools(
tool_union,
ReadonlyContext(invocation_context),
llm_request.model,
agent.model,
multiple_tools,
)
for tool in tools:
@@ -14,13 +14,13 @@
"""Unit tests for BaseLlmFlow toolset integration."""
from typing import Optional
from unittest import mock
from unittest.mock import AsyncMock
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.google_llm import Gemini
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
@@ -95,7 +95,6 @@ async def test_preprocess_calls_toolset_process_llm_request():
async def test_preprocess_handles_mixed_tools_and_toolsets():
"""Test that _preprocess_async properly handles both tools and toolsets."""
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.function_tool import FunctionTool
# Create a mock tool
class _MockTool(BaseTool):
@@ -200,6 +199,46 @@ async def test_preprocess_with_google_search_workaround():
assert {d.name for d in declarations} == {'_my_tool', 'google_search_agent'}
@pytest.mark.asyncio
async def test_preprocess_calls_convert_tool_union_to_tools():
"""Test that _preprocess_async calls _convert_tool_union_to_tools."""
class _MockTool:
process_llm_request = AsyncMock()
mock_tool_instance = _MockTool()
def _my_tool(sides: int) -> int:
"""A simple tool."""
return sides
with mock.patch(
'google.adk.agents.llm_agent._convert_tool_union_to_tools',
new_callable=AsyncMock,
) as mock_convert:
mock_convert.return_value = [mock_tool_instance]
model = Gemini(model='gemini-2')
agent = Agent(
name='test_agent', model=model, tools=[_my_tool, google_search]
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content='test message'
)
flow = BaseLlmFlowForTesting()
llm_request = LlmRequest(model='gemini-2')
async for _ in flow._preprocess_async(invocation_context, llm_request):
pass
mock_convert.assert_called_with(
google_search,
mock.ANY, # ReadonlyContext(invocation_context)
model,
True, # multiple_tools
)
# TODO(b/448114567): Remove the following
# test_handle_after_model_callback_grounding tests once the workaround
# is no longer needed.