diff --git a/contributing/samples/built_in_multi_tools/agent.py b/contributing/samples/built_in_multi_tools/agent.py index 5194e0ef..3eb9ce8b 100644 --- a/contributing/samples/built_in_multi_tools/agent.py +++ b/contributing/samples/built_in_multi_tools/agent.py @@ -17,7 +17,7 @@ import random from dotenv import load_dotenv from google.adk import Agent -from google.adk.tools.google_search_tool import google_search +from google.adk.tools.google_search_tool import GoogleSearchTool from google.adk.tools.tool_context import ToolContext from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool @@ -57,7 +57,9 @@ root_agent = Agent( """, tools=[ roll_die, - VertexAiSearchTool(data_store_id=VERTEXAI_DATASTORE_ID), - google_search, + VertexAiSearchTool( + data_store_id=VERTEXAI_DATASTORE_ID, bypass_multi_tools_limit=True + ), + GoogleSearchTool(bypass_multi_tools_limit=True), ], ) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index f5b242b7..531a5034 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -438,6 +438,10 @@ class BaseLlmFlow(ABC): from ...agents.llm_agent import LlmAgent agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + raise TypeError( + f'Expected agent to be an LlmAgent, but got {type(agent)}' + ) # Runs processors. for processor in self.request_processors: @@ -468,7 +472,7 @@ class BaseLlmFlow(ABC): tools = await _convert_tool_union_to_tools( tool_union, ReadonlyContext(invocation_context), - llm_request.model, + agent.model, multiple_tools, ) for tool in tools: diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 47b4b00f..81ef925a 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -14,13 +14,13 @@ """Unit tests for BaseLlmFlow toolset integration.""" -from typing import Optional +from unittest import mock from unittest.mock import AsyncMock -from google.adk.agents.callback_context import CallbackContext from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin @@ -95,7 +95,6 @@ async def test_preprocess_calls_toolset_process_llm_request(): async def test_preprocess_handles_mixed_tools_and_toolsets(): """Test that _preprocess_async properly handles both tools and toolsets.""" from google.adk.tools.base_tool import BaseTool - from google.adk.tools.function_tool import FunctionTool # Create a mock tool class _MockTool(BaseTool): @@ -200,6 +199,46 @@ async def test_preprocess_with_google_search_workaround(): assert {d.name for d in declarations} == {'_my_tool', 'google_search_agent'} +@pytest.mark.asyncio +async def test_preprocess_calls_convert_tool_union_to_tools(): + """Test that _preprocess_async calls _convert_tool_union_to_tools.""" + + class _MockTool: + process_llm_request = AsyncMock() + + mock_tool_instance = _MockTool() + + def _my_tool(sides: int) -> int: + """A simple tool.""" + return sides + + with mock.patch( + 'google.adk.agents.llm_agent._convert_tool_union_to_tools', + new_callable=AsyncMock, + ) as mock_convert: + mock_convert.return_value = [mock_tool_instance] + + model = Gemini(model='gemini-2') + agent = Agent( + name='test_agent', model=model, tools=[_my_tool, google_search] + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='test message' + ) + flow = BaseLlmFlowForTesting() + llm_request = LlmRequest(model='gemini-2') + + async for _ in flow._preprocess_async(invocation_context, llm_request): + pass + + mock_convert.assert_called_with( + google_search, + mock.ANY, # ReadonlyContext(invocation_context) + model, + True, # multiple_tools + ) + + # TODO(b/448114567): Remove the following # test_handle_after_model_callback_grounding tests once the workaround # is no longer needed.