feat: Support both output_schema and tools at the same time in LlmAgent

1. Allow developers to specify output schema and tools together.
2. If both are specified, do the following:
  2.1 Do not set output schema on the model config
  2.2 Add a special tool called set_model_response(result)
  2.3 `result` has the same schema as the requested output_schema
  2.4 Instruct the model to use set_model_response() to output its final result, rather than output text directly.
  2.5 When the set_model_response() is called, ADK will extract its content and put it in a text part, so the client would treat it as the model response.

PiperOrigin-RevId: 792686011
This commit is contained in:
Xiang (Sean) Zhou
2025-08-08 10:54:50 -07:00
committed by Copybara-Service
parent b4ce3b12d1
commit af635674b5
10 changed files with 1098 additions and 14 deletions
-6
View File
@@ -499,12 +499,6 @@ class LlmAgent(BaseAgent):
' sub_agents must be empty to disable agent transfer.'
)
if self.tools:
raise ValueError(
f'Invalid config for agent {self.name}: if output_schema is set,'
' tools must be empty'
)
@field_validator('generate_content_config', mode='after')
@classmethod
def __validate_generate_content_config(
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Handles output schema when tools are also present."""
from __future__ import annotations
import json
from typing import AsyncGenerator
from typing_extensions import override
from ...agents.invocation_context import InvocationContext
from ...events.event import Event
from ...models.llm_request import LlmRequest
from ...tools.set_model_response_tool import SetModelResponseTool
from ._base_llm_processor import BaseLlmRequestProcessor
class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
"""Processor that handles output schema for agents with tools."""
@override
async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
agent = invocation_context.agent
if not isinstance(agent, LlmAgent):
return
# Check if we need the processor: output_schema + tools
if not agent.output_schema or not agent.tools:
return
# Add the set_model_response tool to handle structured output
set_response_tool = SetModelResponseTool(agent.output_schema)
llm_request.append_tools([set_response_tool])
# Add instruction about using the set_model_response tool
instruction = (
'IMPORTANT: You have access to other tools, but you must provide '
'your final response using the set_model_response tool with the '
'required structured format. After using any other tools needed '
'to complete the task, always call set_model_response with your '
'final answer in the specified schema format.'
)
llm_request.append_instructions([instruction])
return
yield # Generator requires yield statement in function body.
def create_final_model_response_event(
invocation_context: InvocationContext, json_response: str
) -> Event:
"""Create a final model response event from set_model_response JSON.
Args:
invocation_context: The invocation context.
json_response: The JSON response from set_model_response tool.
Returns:
A new Event that looks like a normal model response.
"""
from google.genai import types
# Create a proper model response event
final_event = Event(author=invocation_context.agent.name)
final_event.content = types.Content(
role='model', parts=[types.Part(text=json_response)]
)
return final_event
def get_structured_model_response(function_response_event: Event) -> str | None:
"""Check if function response contains set_model_response and extract JSON.
Args:
function_response_event: The function response event to check.
Returns:
JSON response string if set_model_response was called, None otherwise.
"""
if (
not function_response_event
or not function_response_event.get_function_responses()
):
return None
for func_response in function_response_event.get_function_responses():
if func_response.name == 'set_model_response':
# Convert dict to JSON string
return json.dumps(func_response.response)
return None
# Export the processors
request_processor = _OutputSchemaRequestProcessor()
@@ -28,6 +28,7 @@ from google.genai import types
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import ConnectionClosedOK
from . import _output_schema_processor
from . import functions
from ...agents.base_agent import BaseAgent
from ...agents.callback_context import CallbackContext
@@ -500,8 +501,21 @@ class BaseLlmFlow(ABC):
function_response_event = await functions.handle_function_calls_live(
invocation_context, model_response_event, llm_request.tools_dict
)
# Always yield the function response event first
yield function_response_event
# Check if this is a set_model_response function response
if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
# Create and yield a final model response event
final_event = (
_output_schema_processor.create_final_model_response_event(
invocation_context, json_response
)
)
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
@@ -532,7 +546,20 @@ class BaseLlmFlow(ABC):
if auth_event:
yield auth_event
# Always yield the function response event first
yield function_response_event
# Check if this is a set_model_response function response
if json_response := _output_schema_processor.get_structured_model_response(
function_response_event
):
# Create and yield a final model response event
final_event = (
_output_schema_processor.create_final_model_response_event(
invocation_context, json_response
)
)
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
agent_to_run = self._get_agent_to_run(
+5 -1
View File
@@ -50,7 +50,11 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
if agent.generate_content_config
else types.GenerateContentConfig()
)
if agent.output_schema:
# Only set output_schema if no tools are specified. as of now, model don't
# support output_schema and tools together. we have a workaround to support
# both outoput_schema and tools at the same time. see
# _output_schema_processor.py for details
if agent.output_schema and not agent.tools:
llm_request.set_output_schema(agent.output_schema)
llm_request.live_connect_config.response_modalities = (
@@ -14,10 +14,13 @@
"""Implementation of single flow."""
from __future__ import annotations
import logging
from . import _code_execution
from . import _nl_planning
from . import _output_schema_processor
from . import basic
from . import contents
from . import identity
@@ -50,6 +53,9 @@ class SingleFlow(BaseLlmFlow):
# Code execution should be after the contents as it mutates the contents
# to optimize data files.
_code_execution.request_processor,
# Output schema processor add system instruction and set_model_response
# when both output_schema and tools are present.
_output_schema_processor.request_processor,
]
self.response_processors += [
_nl_planning.response_processor,
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool for setting model response when using output_schema with other tools."""
from __future__ import annotations
from typing import Any
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override
from ._automatic_function_calling_util import build_function_declaration
from .base_tool import BaseTool
from .tool_context import ToolContext
MODEL_JSON_RESPONSE_KEY = 'temp:__adk_model_response__'
class SetModelResponseTool(BaseTool):
"""Internal tool used for output schema workaround.
This tool allows the model to set its final response when output_schema
is configured alongside other tools. The model should use this tool to
provide its final structured response instead of outputting text directly.
"""
def __init__(self, output_schema: type[BaseModel]):
"""Initialize the tool with the expected output schema.
Args:
output_schema: The pydantic model class defining the expected output
structure.
"""
self.output_schema = output_schema
# Create a function that matches the output schema
def set_model_response() -> str:
"""Set your final response using the required output schema.
Use this tool to provide your final structured answer instead
of outputting text directly.
"""
return 'Response set successfully.'
# Add the schema fields as parameters to the function dynamically
import inspect
schema_fields = output_schema.model_fields
params = []
for field_name, field_info in schema_fields.items():
param = inspect.Parameter(
field_name,
inspect.Parameter.KEYWORD_ONLY,
annotation=field_info.annotation,
)
params.append(param)
# Create new signature with schema parameters
new_sig = inspect.Signature(parameters=params)
setattr(set_model_response, '__signature__', new_sig)
self.func = set_model_response
super().__init__(
name=self.func.__name__,
description=self.func.__doc__.strip() if self.func.__doc__ else '',
)
@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
"""Gets the OpenAPI specification of this tool."""
function_decl = types.FunctionDeclaration.model_validate(
build_function_declaration(
func=self.func,
ignore_params=[],
variant=self._api_variant,
)
)
return function_decl
@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext # pylint: disable=unused-argument
) -> dict[str, Any]:
"""Process the model's response and return the validated dict.
Args:
args: The structured response data matching the output schema.
tool_context: Tool execution context.
Returns:
The validated response as dict.
"""
# Validate the input matches the expected schema
validated_response = self.output_schema.model_validate(args)
# Return the validated dict directly
return validated_response.model_dump()
@@ -201,19 +201,18 @@ def test_output_schema_with_sub_agents_will_throw():
)
def test_output_schema_with_tools_will_throw():
def test_output_schema_with_tools_will_not_throw():
class Schema(BaseModel):
pass
def _a_tool():
pass
with pytest.raises(ValueError):
_ = LlmAgent(
name='test_agent',
output_schema=Schema,
tools=[_a_tool],
)
LlmAgent(
name='test_agent',
output_schema=Schema,
tools=[_a_tool],
)
def test_before_model_callback():
@@ -0,0 +1,145 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for basic LLM request processor."""
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
from google.adk.models.llm_request import LlmRequest
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.function_tool import FunctionTool
from pydantic import BaseModel
from pydantic import Field
import pytest
class OutputSchema(BaseModel):
"""Test schema for output."""
name: str = Field(description='A name')
value: int = Field(description='A value')
def dummy_tool(query: str) -> str:
"""A dummy tool for testing."""
return f'Result: {query}'
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
"""Helper to create InvocationContext for testing."""
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
invocation_id='test-id',
agent=agent,
session=session,
session_service=session_service,
run_config=RunConfig(),
)
class TestBasicLlmRequestProcessor:
"""Test class for _BasicLlmRequestProcessor."""
@pytest.mark.asyncio
async def test_sets_output_schema_when_no_tools(self):
"""Test that processor sets output_schema when agent has no tools."""
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=OutputSchema,
tools=[], # No tools
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should have set response_schema since agent has no tools
assert llm_request.config.response_schema == OutputSchema
assert llm_request.config.response_mime_type == 'application/json'
@pytest.mark.asyncio
async def test_skips_output_schema_when_tools_present(self):
"""Test that processor skips output_schema when agent has tools."""
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=OutputSchema,
tools=[FunctionTool(func=dummy_tool)], # Has tools
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should NOT have set response_schema since agent has tools
assert llm_request.config.response_schema is None
assert llm_request.config.response_mime_type != 'application/json'
@pytest.mark.asyncio
async def test_no_output_schema_no_tools(self):
"""Test that processor works normally when agent has no output_schema or tools."""
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
# No output_schema, no tools
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should not have set anything
assert llm_request.config.response_schema is None
assert llm_request.config.response_mime_type != 'application/json'
@pytest.mark.asyncio
async def test_sets_model_name(self):
"""Test that processor sets the model name correctly."""
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should have set the model name
assert llm_request.model == 'gemini-1.5-flash'
@@ -0,0 +1,409 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for output schema processor functionality."""
import json
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.flows.llm_flows.single_flow import SingleFlow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.function_tool import FunctionTool
from pydantic import BaseModel
from pydantic import Field
import pytest
class PersonSchema(BaseModel):
"""Test schema for structured output."""
name: str = Field(description="A person's name")
age: int = Field(description="A person's age")
city: str = Field(description='The city they live in')
def dummy_tool(query: str) -> str:
"""A dummy tool for testing."""
return f'Searched for: {query}'
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
"""Helper to create InvocationContext for testing."""
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
invocation_id='test-id',
agent=agent,
session=session,
session_service=session_service,
run_config=RunConfig(),
)
@pytest.mark.asyncio
async def test_output_schema_with_tools_validation_removed():
"""Test that LlmAgent now allows output_schema with tools."""
# This should not raise an error anymore
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[FunctionTool(func=dummy_tool)],
)
assert agent.output_schema == PersonSchema
assert len(agent.tools) == 1
@pytest.mark.asyncio
async def test_basic_processor_skips_output_schema_with_tools():
"""Test that basic processor doesn't set output_schema when tools are present."""
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[FunctionTool(func=dummy_tool)],
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should not have set response_schema since agent has tools
assert llm_request.config.response_schema is None
assert llm_request.config.response_mime_type != 'application/json'
@pytest.mark.asyncio
async def test_basic_processor_sets_output_schema_without_tools():
"""Test that basic processor still sets output_schema when no tools are present."""
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[], # No tools
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _BasicLlmRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should have set response_schema since agent has no tools
assert llm_request.config.response_schema == PersonSchema
assert llm_request.config.response_mime_type == 'application/json'
@pytest.mark.asyncio
async def test_output_schema_request_processor():
"""Test that output schema processor adds set_model_response tool."""
from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[FunctionTool(func=dummy_tool)],
)
invocation_context = await _create_invocation_context(agent)
llm_request = LlmRequest()
processor = _OutputSchemaRequestProcessor()
# Process the request
events = []
async for event in processor.run_async(invocation_context, llm_request):
events.append(event)
# Should have added set_model_response tool
assert 'set_model_response' in llm_request.tools_dict
# Should have added instruction about using set_model_response
assert 'set_model_response' in llm_request.config.system_instruction
@pytest.mark.asyncio
async def test_set_model_response_tool():
"""Test the set_model_response tool functionality."""
from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY
from google.adk.tools.set_model_response_tool import SetModelResponseTool
from google.adk.tools.tool_context import ToolContext
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# Call the tool with valid data
result = await tool.run_async(
args={'name': 'John Doe', 'age': 30, 'city': 'New York'},
tool_context=tool_context,
)
# Verify the tool now returns dict directly
assert result is not None
assert result['name'] == 'John Doe'
assert result['age'] == 30
assert result['city'] == 'New York'
# Check that the response is no longer stored in session state
stored_response = invocation_context.session.state.get(
MODEL_JSON_RESPONSE_KEY
)
assert stored_response is None
@pytest.mark.asyncio
async def test_output_schema_helper_functions():
"""Test the helper functions for handling set_model_response."""
from google.adk.events.event import Event
from google.adk.flows.llm_flows._output_schema_processor import create_final_model_response_event
from google.adk.flows.llm_flows._output_schema_processor import get_structured_model_response
from google.genai import types
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[FunctionTool(func=dummy_tool)],
)
invocation_context = await _create_invocation_context(agent)
# Test get_structured_model_response with a function response event
test_dict = {'name': 'Jane Smith', 'age': 25, 'city': 'Los Angeles'}
test_json = '{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}'
# Create a function response event with set_model_response
function_response_event = Event(
author='test_agent',
content=types.Content(
role='user',
parts=[
types.Part(
function_response=types.FunctionResponse(
name='set_model_response', response=test_dict
)
)
],
),
)
# Test get_structured_model_response function
extracted_json = get_structured_model_response(function_response_event)
assert extracted_json == test_json
# Test create_final_model_response_event function
final_event = create_final_model_response_event(invocation_context, test_json)
assert final_event.author == 'test_agent'
assert final_event.content.role == 'model'
assert final_event.content.parts[0].text == test_json
# Test get_structured_model_response with non-set_model_response function
other_function_response_event = Event(
author='test_agent',
content=types.Content(
role='user',
parts=[
types.Part(
function_response=types.FunctionResponse(
name='other_tool', response={'result': 'other response'}
)
)
],
),
)
extracted_json = get_structured_model_response(other_function_response_event)
assert extracted_json is None
@pytest.mark.asyncio
async def test_end_to_end_integration():
"""Test the complete output schema with tools integration."""
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[FunctionTool(func=dummy_tool)],
)
invocation_context = await _create_invocation_context(agent)
# Create a flow and test the processors
flow = SingleFlow()
llm_request = LlmRequest()
# Run all request processors
async for event in flow._preprocess_async(invocation_context, llm_request):
pass
# Verify set_model_response tool was added
assert 'set_model_response' in llm_request.tools_dict
# Verify instruction was added
assert 'set_model_response' in llm_request.config.system_instruction
# Verify output_schema was NOT set on the model config
assert llm_request.config.response_schema is None
@pytest.mark.asyncio
async def test_flow_yields_both_events_for_set_model_response():
"""Test that the flow yields both function response and final model response events."""
from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.tools.set_model_response_tool import SetModelResponseTool
from google.genai import types
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
output_schema=PersonSchema,
tools=[],
)
invocation_context = await _create_invocation_context(agent)
flow = BaseLlmFlow()
# Create a set_model_response tool and add it to the tools dict
set_response_tool = SetModelResponseTool(PersonSchema)
llm_request = LlmRequest()
llm_request.tools_dict['set_model_response'] = set_response_tool
# Create a function call event (model calling the function)
function_call_event = Event(
author='test_agent',
content=types.Content(
role='model',
parts=[
types.Part(
function_call=types.FunctionCall(
name='set_model_response',
args={
'name': 'Test User',
'age': 30,
'city': 'Test City',
},
)
)
],
),
)
# Test the postprocess function handling
events = []
async for event in flow._postprocess_handle_function_calls_async(
invocation_context, function_call_event, llm_request
):
events.append(event)
# Should yield exactly 2 events: function response + final model response
assert len(events) == 2
# First event should be the function response
first_event = events[0]
assert first_event.get_function_responses()[0].name == 'set_model_response'
# The response should be the dict returned by the tool
assert first_event.get_function_responses()[0].response == {
'name': 'Test User',
'age': 30,
'city': 'Test City',
}
# Second event should be the final model response with JSON
second_event = events[1]
assert second_event.author == 'test_agent'
assert second_event.content.role == 'model'
assert (
second_event.content.parts[0].text
== '{"name": "Test User", "age": 30, "city": "Test City"}'
)
@pytest.mark.asyncio
async def test_flow_yields_only_function_response_for_normal_tools():
"""Test that the flow yields only function response event for non-set_model_response tools."""
from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.genai import types
agent = LlmAgent(
name='test_agent',
model='gemini-1.5-flash',
tools=[FunctionTool(func=dummy_tool)],
)
invocation_context = await _create_invocation_context(agent)
flow = BaseLlmFlow()
# Create a dummy tool and add it to the tools dict
dummy_function_tool = FunctionTool(func=dummy_tool)
llm_request = LlmRequest()
llm_request.tools_dict['dummy_tool'] = dummy_function_tool
# Create a function call event (model calling the dummy tool)
function_call_event = Event(
author='test_agent',
content=types.Content(
role='model',
parts=[
types.Part(
function_call=types.FunctionCall(
name='dummy_tool', args={'query': 'test query'}
)
)
],
),
)
# Test the postprocess function handling
events = []
async for event in flow._postprocess_handle_function_calls_async(
invocation_context, function_call_event, llm_request
):
events.append(event)
# Should yield exactly 1 event: just the function response
assert len(events) == 1
# Should be the function response from dummy_tool
first_event = events[0]
assert first_event.get_function_responses()[0].name == 'dummy_tool'
assert first_event.get_function_responses()[0].response == {
'result': 'Searched for: test query'
}
@@ -0,0 +1,276 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SetModelResponseTool."""
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY
from google.adk.tools.set_model_response_tool import SetModelResponseTool
from google.adk.tools.tool_context import ToolContext
from pydantic import BaseModel
from pydantic import Field
from pydantic import ValidationError
import pytest
class PersonSchema(BaseModel):
"""Test schema for structured output."""
name: str = Field(description="A person's name")
age: int = Field(description="A person's age")
city: str = Field(description='The city they live in')
class ComplexSchema(BaseModel):
"""More complex test schema."""
id: int
title: str
tags: list[str] = Field(default_factory=list)
metadata: dict[str, str] = Field(default_factory=dict)
is_active: bool = True
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
"""Helper to create InvocationContext for testing."""
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
return InvocationContext(
invocation_id='test-id',
agent=agent,
session=session,
session_service=session_service,
run_config=RunConfig(),
)
def test_tool_initialization_simple_schema():
"""Test tool initialization with a simple schema."""
tool = SetModelResponseTool(PersonSchema)
assert tool.output_schema == PersonSchema
assert tool.name == 'set_model_response'
assert 'Set your final response' in tool.description
assert tool.func is not None
def test_tool_initialization_complex_schema():
"""Test tool initialization with a complex schema."""
tool = SetModelResponseTool(ComplexSchema)
assert tool.output_schema == ComplexSchema
assert tool.name == 'set_model_response'
assert tool.func is not None
def test_function_signature_generation():
"""Test that function signature is correctly generated from schema."""
tool = SetModelResponseTool(PersonSchema)
import inspect
sig = inspect.signature(tool.func)
# Check that parameters match schema fields
assert 'name' in sig.parameters
assert 'age' in sig.parameters
assert 'city' in sig.parameters
# All parameters should be keyword-only
for param in sig.parameters.values():
assert param.kind == inspect.Parameter.KEYWORD_ONLY
def test_get_declaration():
"""Test that tool declaration is properly generated."""
tool = SetModelResponseTool(PersonSchema)
declaration = tool._get_declaration()
assert declaration is not None
assert declaration.name == 'set_model_response'
assert declaration.description is not None
@pytest.mark.asyncio
async def test_run_async_valid_data():
"""Test tool execution with valid data."""
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# Execute with valid data
result = await tool.run_async(
args={'name': 'Alice', 'age': 25, 'city': 'Seattle'},
tool_context=tool_context,
)
# Verify the tool now returns dict directly
assert result is not None
assert result['name'] == 'Alice'
assert result['age'] == 25
assert result['city'] == 'Seattle'
# Verify data is no longer stored in session state (old behavior)
stored_response = invocation_context.session.state.get(
MODEL_JSON_RESPONSE_KEY
)
assert stored_response is None
@pytest.mark.asyncio
async def test_run_async_complex_schema():
"""Test tool execution with complex schema."""
tool = SetModelResponseTool(ComplexSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# Execute with complex data
result = await tool.run_async(
args={
'id': 123,
'title': 'Test Item',
'tags': ['tag1', 'tag2'],
'metadata': {'key': 'value'},
'is_active': False,
},
tool_context=tool_context,
)
# Verify the tool now returns dict directly
assert result is not None
assert result['id'] == 123
assert result['title'] == 'Test Item'
assert result['tags'] == ['tag1', 'tag2']
assert result['metadata'] == {'key': 'value'}
assert result['is_active'] is False
# Verify data is no longer stored in session state (old behavior)
stored_response = invocation_context.session.state.get(
MODEL_JSON_RESPONSE_KEY
)
assert stored_response is None
@pytest.mark.asyncio
async def test_run_async_validation_error():
"""Test tool execution with invalid data raises validation error."""
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# Execute with invalid data (wrong type for age)
with pytest.raises(ValidationError):
await tool.run_async(
args={'name': 'Bob', 'age': 'not_a_number', 'city': 'Portland'},
tool_context=tool_context,
)
@pytest.mark.asyncio
async def test_run_async_missing_required_field():
"""Test tool execution with missing required field."""
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# Execute with missing required field
with pytest.raises(ValidationError):
await tool.run_async(
args={'name': 'Charlie', 'city': 'Denver'}, # Missing age
tool_context=tool_context,
)
@pytest.mark.asyncio
async def test_session_state_storage_key():
"""Test that response is no longer stored in session state."""
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
result = await tool.run_async(
args={'name': 'Diana', 'age': 35, 'city': 'Miami'},
tool_context=tool_context,
)
# Verify response is returned directly, not stored in session state
assert result is not None
assert result['name'] == 'Diana'
assert result['age'] == 35
assert result['city'] == 'Miami'
# Verify session state is no longer used
assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state
@pytest.mark.asyncio
async def test_multiple_executions_return_latest():
"""Test that multiple executions return latest response independently."""
tool = SetModelResponseTool(PersonSchema)
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
invocation_context = await _create_invocation_context(agent)
tool_context = ToolContext(invocation_context)
# First execution
result1 = await tool.run_async(
args={'name': 'First', 'age': 20, 'city': 'City1'},
tool_context=tool_context,
)
# Second execution should return its own response
result2 = await tool.run_async(
args={'name': 'Second', 'age': 30, 'city': 'City2'},
tool_context=tool_context,
)
# Verify each execution returns its own dict
assert result1['name'] == 'First'
assert result1['age'] == 20
assert result1['city'] == 'City1'
assert result2['name'] == 'Second'
assert result2['age'] == 30
assert result2['city'] == 'City2'
# Verify session state is not used
assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state
def test_function_return_value_consistency():
"""Test that function return value matches run_async return value."""
tool = SetModelResponseTool(PersonSchema)
# Direct function call
direct_result = tool.func()
# Both should return the same value
assert direct_result == 'Response set successfully.'