diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 23450df1..8ebcade5 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -499,12 +499,6 @@ class LlmAgent(BaseAgent): ' sub_agents must be empty to disable agent transfer.' ) - if self.tools: - raise ValueError( - f'Invalid config for agent {self.name}: if output_schema is set,' - ' tools must be empty' - ) - @field_validator('generate_content_config', mode='after') @classmethod def __validate_generate_content_config( diff --git a/src/google/adk/flows/llm_flows/_output_schema_processor.py b/src/google/adk/flows/llm_flows/_output_schema_processor.py new file mode 100644 index 00000000..16638702 --- /dev/null +++ b/src/google/adk/flows/llm_flows/_output_schema_processor.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Handles output schema when tools are also present.""" + +from __future__ import annotations + +import json +from typing import AsyncGenerator + +from typing_extensions import override + +from ...agents.invocation_context import InvocationContext +from ...events.event import Event +from ...models.llm_request import LlmRequest +from ...tools.set_model_response_tool import SetModelResponseTool +from ._base_llm_processor import BaseLlmRequestProcessor + + +class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor): + """Processor that handles output schema for agents with tools.""" + + @override + async def run_async( + self, invocation_context: InvocationContext, llm_request: LlmRequest + ) -> AsyncGenerator[Event, None]: + from ...agents.llm_agent import LlmAgent + + agent = invocation_context.agent + if not isinstance(agent, LlmAgent): + return + + # Check if we need the processor: output_schema + tools + if not agent.output_schema or not agent.tools: + return + + # Add the set_model_response tool to handle structured output + set_response_tool = SetModelResponseTool(agent.output_schema) + llm_request.append_tools([set_response_tool]) + + # Add instruction about using the set_model_response tool + instruction = ( + 'IMPORTANT: You have access to other tools, but you must provide ' + 'your final response using the set_model_response tool with the ' + 'required structured format. After using any other tools needed ' + 'to complete the task, always call set_model_response with your ' + 'final answer in the specified schema format.' + ) + llm_request.append_instructions([instruction]) + + return + yield # Generator requires yield statement in function body. + + +def create_final_model_response_event( + invocation_context: InvocationContext, json_response: str +) -> Event: + """Create a final model response event from set_model_response JSON. + + Args: + invocation_context: The invocation context. + json_response: The JSON response from set_model_response tool. 
+ + Returns: + A new Event that looks like a normal model response. + """ + from google.genai import types + + # Create a proper model response event + final_event = Event(author=invocation_context.agent.name) + final_event.content = types.Content( + role='model', parts=[types.Part(text=json_response)] + ) + return final_event + + +def get_structured_model_response(function_response_event: Event) -> str | None: + """Check if function response contains set_model_response and extract JSON. + + Args: + function_response_event: The function response event to check. + + Returns: + JSON response string if set_model_response was called, None otherwise. + """ + if ( + not function_response_event + or not function_response_event.get_function_responses() + ): + return None + + for func_response in function_response_event.get_function_responses(): + if func_response.name == 'set_model_response': + # Convert dict to JSON string + return json.dumps(func_response.response) + + return None + + +# Export the processors +request_processor = _OutputSchemaRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 0a1cdb91..90cf0fbc 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -28,6 +28,7 @@ from google.genai import types from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK +from . import _output_schema_processor from . 
import functions from ...agents.base_agent import BaseAgent from ...agents.callback_context import CallbackContext @@ -500,8 +501,21 @@ class BaseLlmFlow(ABC): function_response_event = await functions.handle_function_calls_live( invocation_context, model_response_event, llm_request.tools_dict ) + # Always yield the function response event first yield function_response_event + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event + transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( @@ -532,7 +546,20 @@ class BaseLlmFlow(ABC): if auth_event: yield auth_event + # Always yield the function response event first yield function_response_event + + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( diff --git a/src/google/adk/flows/llm_flows/basic.py b/src/google/adk/flows/llm_flows/basic.py index c5dfbd1c..549c6d87 100644 --- a/src/google/adk/flows/llm_flows/basic.py +++ b/src/google/adk/flows/llm_flows/basic.py @@ -50,7 +50,11 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor): if agent.generate_content_config else types.GenerateContentConfig() ) - if agent.output_schema: + # Only set output_schema if no tools are specified. 
Models currently don't + # support output_schema and tools together. We have a workaround to support + # both output_schema and tools at the same time. See + # _output_schema_processor.py for details. + if agent.output_schema and not agent.tools: llm_request.set_output_schema(agent.output_schema) llm_request.live_connect_config.response_modalities = ( diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index 787a7679..5b398b52 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -14,10 +14,13 @@ """Implementation of single flow.""" +from __future__ import annotations + import logging from . import _code_execution from . import _nl_planning +from . import _output_schema_processor from . import basic from . import contents from . import identity @@ -50,6 +53,9 @@ class SingleFlow(BaseLlmFlow): # Code execution should be after the contents as it mutates the contents # to optimize data files. _code_execution.request_processor, + # Output schema processor adds a system instruction and the + # set_model_response tool when output_schema and tools are both set. + _output_schema_processor.request_processor, ] self.response_processors += [ _nl_planning.response_processor, diff --git a/src/google/adk/tools/set_model_response_tool.py b/src/google/adk/tools/set_model_response_tool.py new file mode 100644 index 00000000..6b27d55c --- /dev/null +++ b/src/google/adk/tools/set_model_response_tool.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool for setting model response when using output_schema with other tools.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from google.genai import types +from pydantic import BaseModel +from typing_extensions import override + +from ._automatic_function_calling_util import build_function_declaration +from .base_tool import BaseTool +from .tool_context import ToolContext + +MODEL_JSON_RESPONSE_KEY = 'temp:__adk_model_response__' + + +class SetModelResponseTool(BaseTool): + """Internal tool used for output schema workaround. + + This tool allows the model to set its final response when output_schema + is configured alongside other tools. The model should use this tool to + provide its final structured response instead of outputting text directly. + """ + + def __init__(self, output_schema: type[BaseModel]): + """Initialize the tool with the expected output schema. + + Args: + output_schema: The pydantic model class defining the expected output + structure. + """ + self.output_schema = output_schema + + # Create a function that matches the output schema + def set_model_response() -> str: + """Set your final response using the required output schema. + + Use this tool to provide your final structured answer instead + of outputting text directly. + """ + return 'Response set successfully.' 
+ + # Add the schema fields as parameters to the function dynamically + import inspect + + schema_fields = output_schema.model_fields + params = [] + for field_name, field_info in schema_fields.items(): + param = inspect.Parameter( + field_name, + inspect.Parameter.KEYWORD_ONLY, + annotation=field_info.annotation, + ) + params.append(param) + + # Create new signature with schema parameters + new_sig = inspect.Signature(parameters=params) + setattr(set_model_response, '__signature__', new_sig) + + self.func = set_model_response + + super().__init__( + name=self.func.__name__, + description=self.func.__doc__.strip() if self.func.__doc__ else '', + ) + + @override + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + """Gets the OpenAPI specification of this tool.""" + function_decl = types.FunctionDeclaration.model_validate( + build_function_declaration( + func=self.func, + ignore_params=[], + variant=self._api_variant, + ) + ) + return function_decl + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext # pylint: disable=unused-argument + ) -> dict[str, Any]: + """Process the model's response and return the validated dict. + + Args: + args: The structured response data matching the output schema. + tool_context: Tool execution context. + + Returns: + The validated response as dict. 
+ """ + # Validate the input matches the expected schema + validated_response = self.output_schema.model_validate(args) + + # Return the validated dict directly + return validated_response.model_dump() diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 9b3a4abc..e62cf4e8 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -201,19 +201,18 @@ def test_output_schema_with_sub_agents_will_throw(): ) -def test_output_schema_with_tools_will_throw(): +def test_output_schema_with_tools_will_not_throw(): class Schema(BaseModel): pass def _a_tool(): pass - with pytest.raises(ValueError): - _ = LlmAgent( - name='test_agent', - output_schema=Schema, - tools=[_a_tool], - ) + LlmAgent( + name='test_agent', + output_schema=Schema, + tools=[_a_tool], + ) def test_before_model_callback(): diff --git a/tests/unittests/flows/llm_flows/test_basic_processor.py b/tests/unittests/flows/llm_flows/test_basic_processor.py new file mode 100644 index 00000000..770f3589 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_basic_processor.py @@ -0,0 +1,145 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for basic LLM request processor.""" + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.function_tool import FunctionTool +from pydantic import BaseModel +from pydantic import Field +import pytest + + +class OutputSchema(BaseModel): + """Test schema for output.""" + + name: str = Field(description='A name') + value: int = Field(description='A value') + + +def dummy_tool(query: str) -> str: + """A dummy tool for testing.""" + return f'Result: {query}' + + +async def _create_invocation_context(agent: LlmAgent) -> InvocationContext: + """Helper to create InvocationContext for testing.""" + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + return InvocationContext( + invocation_id='test-id', + agent=agent, + session=session, + session_service=session_service, + run_config=RunConfig(), + ) + + +class TestBasicLlmRequestProcessor: + """Test class for _BasicLlmRequestProcessor.""" + + @pytest.mark.asyncio + async def test_sets_output_schema_when_no_tools(self): + """Test that processor sets output_schema when agent has no tools.""" + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=OutputSchema, + tools=[], # No tools + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should have set response_schema since agent has no tools + assert 
llm_request.config.response_schema == OutputSchema + assert llm_request.config.response_mime_type == 'application/json' + + @pytest.mark.asyncio + async def test_skips_output_schema_when_tools_present(self): + """Test that processor skips output_schema when agent has tools.""" + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=OutputSchema, + tools=[FunctionTool(func=dummy_tool)], # Has tools + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should NOT have set response_schema since agent has tools + assert llm_request.config.response_schema is None + assert llm_request.config.response_mime_type != 'application/json' + + @pytest.mark.asyncio + async def test_no_output_schema_no_tools(self): + """Test that processor works normally when agent has no output_schema or tools.""" + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + # No output_schema, no tools + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should not have set anything + assert llm_request.config.response_schema is None + assert llm_request.config.response_mime_type != 'application/json' + + @pytest.mark.asyncio + async def test_sets_model_name(self): + """Test that processor sets the model name correctly.""" + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + ) + + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in 
processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should have set the model name + assert llm_request.model == 'gemini-1.5-flash' diff --git a/tests/unittests/flows/llm_flows/test_output_schema_processor.py b/tests/unittests/flows/llm_flows/test_output_schema_processor.py new file mode 100644 index 00000000..42bfa880 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_output_schema_processor.py @@ -0,0 +1,409 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for output schema processor functionality.""" + +import json + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.flows.llm_flows.single_flow import SingleFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.function_tool import FunctionTool +from pydantic import BaseModel +from pydantic import Field +import pytest + + +class PersonSchema(BaseModel): + """Test schema for structured output.""" + + name: str = Field(description="A person's name") + age: int = Field(description="A person's age") + city: str = Field(description='The city they live in') + + +def dummy_tool(query: str) -> str: + """A dummy tool for testing.""" + return f'Searched for: {query}' + + +async def _create_invocation_context(agent: LlmAgent) -> InvocationContext: + """Helper to create InvocationContext for testing.""" + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + return InvocationContext( + invocation_id='test-id', + agent=agent, + session=session, + session_service=session_service, + run_config=RunConfig(), + ) + + +@pytest.mark.asyncio +async def test_output_schema_with_tools_validation_removed(): + """Test that LlmAgent now allows output_schema with tools.""" + # This should not raise an error anymore + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + + assert agent.output_schema == PersonSchema + assert len(agent.tools) == 1 + + +@pytest.mark.asyncio +async def test_basic_processor_skips_output_schema_with_tools(): + """Test that basic processor doesn't set output_schema when tools are 
present.""" + from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + + invocation_context = await _create_invocation_context(agent) + + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should not have set response_schema since agent has tools + assert llm_request.config.response_schema is None + assert llm_request.config.response_mime_type != 'application/json' + + +@pytest.mark.asyncio +async def test_basic_processor_sets_output_schema_without_tools(): + """Test that basic processor still sets output_schema when no tools are present.""" + from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[], # No tools + ) + + invocation_context = await _create_invocation_context(agent) + + llm_request = LlmRequest() + processor = _BasicLlmRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should have set response_schema since agent has no tools + assert llm_request.config.response_schema == PersonSchema + assert llm_request.config.response_mime_type == 'application/json' + + +@pytest.mark.asyncio +async def test_output_schema_request_processor(): + """Test that output schema processor adds set_model_response tool.""" + from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + + invocation_context = await _create_invocation_context(agent) + 
+ llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + # Process the request + events = [] + async for event in processor.run_async(invocation_context, llm_request): + events.append(event) + + # Should have added set_model_response tool + assert 'set_model_response' in llm_request.tools_dict + + # Should have added instruction about using set_model_response + assert 'set_model_response' in llm_request.config.system_instruction + + +@pytest.mark.asyncio +async def test_set_model_response_tool(): + """Test the set_model_response tool functionality.""" + from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY + from google.adk.tools.set_model_response_tool import SetModelResponseTool + from google.adk.tools.tool_context import ToolContext + + tool = SetModelResponseTool(PersonSchema) + + agent = LlmAgent(name='test_agent', model='gemini-1.5-flash') + invocation_context = await _create_invocation_context(agent) + tool_context = ToolContext(invocation_context) + + # Call the tool with valid data + result = await tool.run_async( + args={'name': 'John Doe', 'age': 30, 'city': 'New York'}, + tool_context=tool_context, + ) + + # Verify the tool now returns dict directly + assert result is not None + assert result['name'] == 'John Doe' + assert result['age'] == 30 + assert result['city'] == 'New York' + + # Check that the response is no longer stored in session state + stored_response = invocation_context.session.state.get( + MODEL_JSON_RESPONSE_KEY + ) + assert stored_response is None + + +@pytest.mark.asyncio +async def test_output_schema_helper_functions(): + """Test the helper functions for handling set_model_response.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows._output_schema_processor import create_final_model_response_event + from google.adk.flows.llm_flows._output_schema_processor import get_structured_model_response + from google.genai import types + + agent = LlmAgent( + 
name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + + invocation_context = await _create_invocation_context(agent) + + # Test get_structured_model_response with a function response event + test_dict = {'name': 'Jane Smith', 'age': 25, 'city': 'Los Angeles'} + test_json = '{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}' + + # Create a function response event with set_model_response + function_response_event = Event( + author='test_agent', + content=types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='set_model_response', response=test_dict + ) + ) + ], + ), + ) + + # Test get_structured_model_response function + extracted_json = get_structured_model_response(function_response_event) + assert extracted_json == test_json + + # Test create_final_model_response_event function + final_event = create_final_model_response_event(invocation_context, test_json) + assert final_event.author == 'test_agent' + assert final_event.content.role == 'model' + assert final_event.content.parts[0].text == test_json + + # Test get_structured_model_response with non-set_model_response function + other_function_response_event = Event( + author='test_agent', + content=types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='other_tool', response={'result': 'other response'} + ) + ) + ], + ), + ) + + extracted_json = get_structured_model_response(other_function_response_event) + assert extracted_json is None + + +@pytest.mark.asyncio +async def test_end_to_end_integration(): + """Test the complete output schema with tools integration.""" + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + + invocation_context = await _create_invocation_context(agent) + + # Create a flow and test the processors + flow = SingleFlow() + 
llm_request = LlmRequest() + + # Run all request processors + async for event in flow._preprocess_async(invocation_context, llm_request): + pass + + # Verify set_model_response tool was added + assert 'set_model_response' in llm_request.tools_dict + + # Verify instruction was added + assert 'set_model_response' in llm_request.config.system_instruction + + # Verify output_schema was NOT set on the model config + assert llm_request.config.response_schema is None + + +@pytest.mark.asyncio +async def test_flow_yields_both_events_for_set_model_response(): + """Test that the flow yields both function response and final model response events.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + from google.adk.tools.set_model_response_tool import SetModelResponseTool + from google.genai import types + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[], + ) + + invocation_context = await _create_invocation_context(agent) + flow = BaseLlmFlow() + + # Create a set_model_response tool and add it to the tools dict + set_response_tool = SetModelResponseTool(PersonSchema) + llm_request = LlmRequest() + llm_request.tools_dict['set_model_response'] = set_response_tool + + # Create a function call event (model calling the function) + function_call_event = Event( + author='test_agent', + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='set_model_response', + args={ + 'name': 'Test User', + 'age': 30, + 'city': 'Test City', + }, + ) + ) + ], + ), + ) + + # Test the postprocess function handling + events = [] + async for event in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(event) + + # Should yield exactly 2 events: function response + final model response + assert len(events) == 2 + + # First event should be the function response 
+ first_event = events[0] + assert first_event.get_function_responses()[0].name == 'set_model_response' + # The response should be the dict returned by the tool + assert first_event.get_function_responses()[0].response == { + 'name': 'Test User', + 'age': 30, + 'city': 'Test City', + } + + # Second event should be the final model response with JSON + second_event = events[1] + assert second_event.author == 'test_agent' + assert second_event.content.role == 'model' + assert ( + second_event.content.parts[0].text + == '{"name": "Test User", "age": 30, "city": "Test City"}' + ) + + +@pytest.mark.asyncio +async def test_flow_yields_only_function_response_for_normal_tools(): + """Test that the flow yields only function response event for non-set_model_response tools.""" + from google.adk.events.event import Event + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + from google.genai import types + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + tools=[FunctionTool(func=dummy_tool)], + ) + + invocation_context = await _create_invocation_context(agent) + flow = BaseLlmFlow() + + # Create a dummy tool and add it to the tools dict + dummy_function_tool = FunctionTool(func=dummy_tool) + llm_request = LlmRequest() + llm_request.tools_dict['dummy_tool'] = dummy_function_tool + + # Create a function call event (model calling the dummy tool) + function_call_event = Event( + author='test_agent', + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='dummy_tool', args={'query': 'test query'} + ) + ) + ], + ), + ) + + # Test the postprocess function handling + events = [] + async for event in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(event) + + # Should yield exactly 1 event: just the function response + assert len(events) == 1 + + # Should be the function response from dummy_tool + first_event = events[0] + 
assert first_event.get_function_responses()[0].name == 'dummy_tool' + assert first_event.get_function_responses()[0].response == { + 'result': 'Searched for: test query' + } diff --git a/tests/unittests/tools/test_set_model_response_tool.py b/tests/unittests/tools/test_set_model_response_tool.py new file mode 100644 index 00000000..ca768a9e --- /dev/null +++ b/tests/unittests/tools/test_set_model_response_tool.py @@ -0,0 +1,276 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for SetModelResponseTool.

The tests cover tool construction, the signature and declaration generated
from the output schema, and `run_async`, which is expected to validate its
arguments against the schema and return them as a dict without writing
anything to session state under MODEL_JSON_RESPONSE_KEY.
"""

import inspect

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY
from google.adk.tools.set_model_response_tool import SetModelResponseTool
from google.adk.tools.tool_context import ToolContext
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
import pytest


class PersonSchema(BaseModel):
    """Test schema for structured output."""

    name: str = Field(description="A person's name")
    age: int = Field(description="A person's age")
    city: str = Field(description='The city they live in')


class ComplexSchema(BaseModel):
    """More complex test schema."""

    id: int
    title: str
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, str] = Field(default_factory=dict)
    is_active: bool = True


async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
    """Create an InvocationContext for `agent` with a new in-memory session.

    Args:
        agent: The agent the invocation context is bound to.

    Returns:
        An InvocationContext using a freshly created session.
    """
    session_service = InMemorySessionService()
    session = await session_service.create_session(
        app_name='test_app', user_id='test_user'
    )
    return InvocationContext(
        invocation_id='test-id',
        agent=agent,
        session=session,
        session_service=session_service,
        run_config=RunConfig(),
    )


async def _create_tool_context() -> tuple[InvocationContext, ToolContext]:
    """Create the invocation and tool contexts used by the run_async tests.

    Returns:
        The invocation context (so tests can inspect session state) and a
        ToolContext wrapping it.
    """
    agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
    invocation_context = await _create_invocation_context(agent)
    return invocation_context, ToolContext(invocation_context)


def test_tool_initialization_simple_schema():
    """Test tool initialization with a simple schema."""
    tool = SetModelResponseTool(PersonSchema)

    assert tool.output_schema is PersonSchema
    assert tool.name == 'set_model_response'
    assert 'Set your final response' in tool.description
    assert tool.func is not None


def test_tool_initialization_complex_schema():
    """Test tool initialization with a complex schema."""
    tool = SetModelResponseTool(ComplexSchema)

    assert tool.output_schema is ComplexSchema
    assert tool.name == 'set_model_response'
    assert tool.func is not None


def test_function_signature_generation():
    """Test that function signature is correctly generated from schema."""
    tool = SetModelResponseTool(PersonSchema)

    sig = inspect.signature(tool.func)

    # Check that parameters match schema fields
    assert 'name' in sig.parameters
    assert 'age' in sig.parameters
    assert 'city' in sig.parameters

    # All parameters should be keyword-only
    for param in sig.parameters.values():
        assert param.kind == inspect.Parameter.KEYWORD_ONLY


def test_get_declaration():
    """Test that tool declaration is properly generated."""
    tool = SetModelResponseTool(PersonSchema)

    declaration = tool._get_declaration()

    assert declaration is not None
    assert declaration.name == 'set_model_response'
    assert declaration.description is not None


@pytest.mark.asyncio
async def test_run_async_valid_data():
    """Test tool execution with valid data."""
    tool = SetModelResponseTool(PersonSchema)
    invocation_context, tool_context = await _create_tool_context()

    # Execute with valid data
    result = await tool.run_async(
        args={'name': 'Alice', 'age': 25, 'city': 'Seattle'},
        tool_context=tool_context,
    )

    # Verify the tool now returns dict directly
    assert result is not None
    assert result['name'] == 'Alice'
    assert result['age'] == 25
    assert result['city'] == 'Seattle'

    # Verify data is no longer stored in session state (old behavior)
    stored_response = invocation_context.session.state.get(
        MODEL_JSON_RESPONSE_KEY
    )
    assert stored_response is None


@pytest.mark.asyncio
async def test_run_async_complex_schema():
    """Test tool execution with complex schema."""
    tool = SetModelResponseTool(ComplexSchema)
    invocation_context, tool_context = await _create_tool_context()

    # Execute with complex data
    result = await tool.run_async(
        args={
            'id': 123,
            'title': 'Test Item',
            'tags': ['tag1', 'tag2'],
            'metadata': {'key': 'value'},
            'is_active': False,
        },
        tool_context=tool_context,
    )

    # Verify the tool now returns dict directly
    assert result is not None
    assert result['id'] == 123
    assert result['title'] == 'Test Item'
    assert result['tags'] == ['tag1', 'tag2']
    assert result['metadata'] == {'key': 'value'}
    assert result['is_active'] is False

    # Verify data is no longer stored in session state (old behavior)
    stored_response = invocation_context.session.state.get(
        MODEL_JSON_RESPONSE_KEY
    )
    assert stored_response is None


@pytest.mark.asyncio
async def test_run_async_validation_error():
    """Test tool execution with invalid data raises validation error."""
    tool = SetModelResponseTool(PersonSchema)
    _, tool_context = await _create_tool_context()

    # Execute with invalid data (wrong type for age)
    with pytest.raises(ValidationError):
        await tool.run_async(
            args={'name': 'Bob', 'age': 'not_a_number', 'city': 'Portland'},
            tool_context=tool_context,
        )


@pytest.mark.asyncio
async def test_run_async_missing_required_field():
    """Test tool execution with missing required field."""
    tool = SetModelResponseTool(PersonSchema)
    _, tool_context = await _create_tool_context()

    # Execute with missing required field
    with pytest.raises(ValidationError):
        await tool.run_async(
            args={'name': 'Charlie', 'city': 'Denver'},  # Missing age
            tool_context=tool_context,
        )


@pytest.mark.asyncio
async def test_session_state_storage_key():
    """Test that response is no longer stored in session state."""
    tool = SetModelResponseTool(PersonSchema)
    invocation_context, tool_context = await _create_tool_context()

    result = await tool.run_async(
        args={'name': 'Diana', 'age': 35, 'city': 'Miami'},
        tool_context=tool_context,
    )

    # Verify response is returned directly, not stored in session state
    assert result is not None
    assert result['name'] == 'Diana'
    assert result['age'] == 35
    assert result['city'] == 'Miami'

    # Verify session state is no longer used
    assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state


@pytest.mark.asyncio
async def test_multiple_executions_return_latest():
    """Test that each of several executions returns its own response."""
    tool = SetModelResponseTool(PersonSchema)
    invocation_context, tool_context = await _create_tool_context()

    # First execution
    result1 = await tool.run_async(
        args={'name': 'First', 'age': 20, 'city': 'City1'},
        tool_context=tool_context,
    )

    # Second execution should return its own response
    result2 = await tool.run_async(
        args={'name': 'Second', 'age': 30, 'city': 'City2'},
        tool_context=tool_context,
    )

    # Verify each execution returns its own dict
    assert result1['name'] == 'First'
    assert result1['age'] == 20
    assert result1['city'] == 'City1'

    assert result2['name'] == 'Second'
    assert result2['age'] == 30
    assert result2['city'] == 'City2'

    # Verify session state is not used
    assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state


def test_function_return_value_consistency():
    """Test that calling the wrapped function directly returns a fixed string.

    Only `tool.func` is exercised here; `run_async` is covered by the async
    tests in this module, which expect it to return a dict instead.
    """
    tool = SetModelResponseTool(PersonSchema)

    # Direct function call, bypassing run_async and its schema validation
    direct_result = tool.func()

    # The wrapped function only acknowledges the call
    assert direct_result == 'Response set successfully.'