You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support both output_schema and tools at the same time in LlmAgent
1. Allow developers to specify output schema and tools together. 2. If both are specified, do the following: 2.1 Do not set output schema on the model config 2.2 Add a special tool called set_model_response(result) 2.3 `result` has the same schema as the requested output_schema 2.4 Instruct the model to use set_model_response() to output its final result, rather than output text directly. 2.5 When the set_model_response() is called, ADK will extract its content and put it in a text part, so the client would treat it as the model response. PiperOrigin-RevId: 792686011
This commit is contained in:
committed by
Copybara-Service
parent
b4ce3b12d1
commit
af635674b5
@@ -499,12 +499,6 @@ class LlmAgent(BaseAgent):
|
||||
' sub_agents must be empty to disable agent transfer.'
|
||||
)
|
||||
|
||||
if self.tools:
|
||||
raise ValueError(
|
||||
f'Invalid config for agent {self.name}: if output_schema is set,'
|
||||
' tools must be empty'
|
||||
)
|
||||
|
||||
@field_validator('generate_content_config', mode='after')
|
||||
@classmethod
|
||||
def __validate_generate_content_config(
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Handles output schema when tools are also present."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.invocation_context import InvocationContext
|
||||
from ...events.event import Event
|
||||
from ...models.llm_request import LlmRequest
|
||||
from ...tools.set_model_response_tool import SetModelResponseTool
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
|
||||
|
||||
class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Processor that handles output schema for agents with tools."""
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
agent = invocation_context.agent
|
||||
if not isinstance(agent, LlmAgent):
|
||||
return
|
||||
|
||||
# Check if we need the processor: output_schema + tools
|
||||
if not agent.output_schema or not agent.tools:
|
||||
return
|
||||
|
||||
# Add the set_model_response tool to handle structured output
|
||||
set_response_tool = SetModelResponseTool(agent.output_schema)
|
||||
llm_request.append_tools([set_response_tool])
|
||||
|
||||
# Add instruction about using the set_model_response tool
|
||||
instruction = (
|
||||
'IMPORTANT: You have access to other tools, but you must provide '
|
||||
'your final response using the set_model_response tool with the '
|
||||
'required structured format. After using any other tools needed '
|
||||
'to complete the task, always call set_model_response with your '
|
||||
'final answer in the specified schema format.'
|
||||
)
|
||||
llm_request.append_instructions([instruction])
|
||||
|
||||
return
|
||||
yield # Generator requires yield statement in function body.
|
||||
|
||||
|
||||
def create_final_model_response_event(
|
||||
invocation_context: InvocationContext, json_response: str
|
||||
) -> Event:
|
||||
"""Create a final model response event from set_model_response JSON.
|
||||
|
||||
Args:
|
||||
invocation_context: The invocation context.
|
||||
json_response: The JSON response from set_model_response tool.
|
||||
|
||||
Returns:
|
||||
A new Event that looks like a normal model response.
|
||||
"""
|
||||
from google.genai import types
|
||||
|
||||
# Create a proper model response event
|
||||
final_event = Event(author=invocation_context.agent.name)
|
||||
final_event.content = types.Content(
|
||||
role='model', parts=[types.Part(text=json_response)]
|
||||
)
|
||||
return final_event
|
||||
|
||||
|
||||
def get_structured_model_response(function_response_event: Event) -> str | None:
|
||||
"""Check if function response contains set_model_response and extract JSON.
|
||||
|
||||
Args:
|
||||
function_response_event: The function response event to check.
|
||||
|
||||
Returns:
|
||||
JSON response string if set_model_response was called, None otherwise.
|
||||
"""
|
||||
if (
|
||||
not function_response_event
|
||||
or not function_response_event.get_function_responses()
|
||||
):
|
||||
return None
|
||||
|
||||
for func_response in function_response_event.get_function_responses():
|
||||
if func_response.name == 'set_model_response':
|
||||
# Convert dict to JSON string
|
||||
return json.dumps(func_response.response)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Export the processors
|
||||
request_processor = _OutputSchemaRequestProcessor()
|
||||
@@ -28,6 +28,7 @@ from google.genai import types
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from . import _output_schema_processor
|
||||
from . import functions
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...agents.callback_context import CallbackContext
|
||||
@@ -500,8 +501,21 @@ class BaseLlmFlow(ABC):
|
||||
function_response_event = await functions.handle_function_calls_live(
|
||||
invocation_context, model_response_event, llm_request.tools_dict
|
||||
)
|
||||
# Always yield the function response event first
|
||||
yield function_response_event
|
||||
|
||||
# Check if this is a set_model_response function response
|
||||
if json_response := _output_schema_processor.get_structured_model_response(
|
||||
function_response_event
|
||||
):
|
||||
# Create and yield a final model response event
|
||||
final_event = (
|
||||
_output_schema_processor.create_final_model_response_event(
|
||||
invocation_context, json_response
|
||||
)
|
||||
)
|
||||
yield final_event
|
||||
|
||||
transfer_to_agent = function_response_event.actions.transfer_to_agent
|
||||
if transfer_to_agent:
|
||||
agent_to_run = self._get_agent_to_run(
|
||||
@@ -532,7 +546,20 @@ class BaseLlmFlow(ABC):
|
||||
if auth_event:
|
||||
yield auth_event
|
||||
|
||||
# Always yield the function response event first
|
||||
yield function_response_event
|
||||
|
||||
# Check if this is a set_model_response function response
|
||||
if json_response := _output_schema_processor.get_structured_model_response(
|
||||
function_response_event
|
||||
):
|
||||
# Create and yield a final model response event
|
||||
final_event = (
|
||||
_output_schema_processor.create_final_model_response_event(
|
||||
invocation_context, json_response
|
||||
)
|
||||
)
|
||||
yield final_event
|
||||
transfer_to_agent = function_response_event.actions.transfer_to_agent
|
||||
if transfer_to_agent:
|
||||
agent_to_run = self._get_agent_to_run(
|
||||
|
||||
@@ -50,7 +50,11 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
if agent.generate_content_config
|
||||
else types.GenerateContentConfig()
|
||||
)
|
||||
if agent.output_schema:
|
||||
# Only set output_schema if no tools are specified. as of now, model don't
|
||||
# support output_schema and tools together. we have a workaround to support
|
||||
# both outoput_schema and tools at the same time. see
|
||||
# _output_schema_processor.py for details
|
||||
if agent.output_schema and not agent.tools:
|
||||
llm_request.set_output_schema(agent.output_schema)
|
||||
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
|
||||
"""Implementation of single flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from . import _code_execution
|
||||
from . import _nl_planning
|
||||
from . import _output_schema_processor
|
||||
from . import basic
|
||||
from . import contents
|
||||
from . import identity
|
||||
@@ -50,6 +53,9 @@ class SingleFlow(BaseLlmFlow):
|
||||
# Code execution should be after the contents as it mutates the contents
|
||||
# to optimize data files.
|
||||
_code_execution.request_processor,
|
||||
# Output schema processor add system instruction and set_model_response
|
||||
# when both output_schema and tools are present.
|
||||
_output_schema_processor.request_processor,
|
||||
]
|
||||
self.response_processors += [
|
||||
_nl_planning.response_processor,
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tool for setting model response when using output_schema with other tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from ._automatic_function_calling_util import build_function_declaration
|
||||
from .base_tool import BaseTool
|
||||
from .tool_context import ToolContext
|
||||
|
||||
MODEL_JSON_RESPONSE_KEY = 'temp:__adk_model_response__'
|
||||
|
||||
|
||||
class SetModelResponseTool(BaseTool):
|
||||
"""Internal tool used for output schema workaround.
|
||||
|
||||
This tool allows the model to set its final response when output_schema
|
||||
is configured alongside other tools. The model should use this tool to
|
||||
provide its final structured response instead of outputting text directly.
|
||||
"""
|
||||
|
||||
def __init__(self, output_schema: type[BaseModel]):
|
||||
"""Initialize the tool with the expected output schema.
|
||||
|
||||
Args:
|
||||
output_schema: The pydantic model class defining the expected output
|
||||
structure.
|
||||
"""
|
||||
self.output_schema = output_schema
|
||||
|
||||
# Create a function that matches the output schema
|
||||
def set_model_response() -> str:
|
||||
"""Set your final response using the required output schema.
|
||||
|
||||
Use this tool to provide your final structured answer instead
|
||||
of outputting text directly.
|
||||
"""
|
||||
return 'Response set successfully.'
|
||||
|
||||
# Add the schema fields as parameters to the function dynamically
|
||||
import inspect
|
||||
|
||||
schema_fields = output_schema.model_fields
|
||||
params = []
|
||||
for field_name, field_info in schema_fields.items():
|
||||
param = inspect.Parameter(
|
||||
field_name,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
annotation=field_info.annotation,
|
||||
)
|
||||
params.append(param)
|
||||
|
||||
# Create new signature with schema parameters
|
||||
new_sig = inspect.Signature(parameters=params)
|
||||
setattr(set_model_response, '__signature__', new_sig)
|
||||
|
||||
self.func = set_model_response
|
||||
|
||||
super().__init__(
|
||||
name=self.func.__name__,
|
||||
description=self.func.__doc__.strip() if self.func.__doc__ else '',
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
"""Gets the OpenAPI specification of this tool."""
|
||||
function_decl = types.FunctionDeclaration.model_validate(
|
||||
build_function_declaration(
|
||||
func=self.func,
|
||||
ignore_params=[],
|
||||
variant=self._api_variant,
|
||||
)
|
||||
)
|
||||
return function_decl
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext # pylint: disable=unused-argument
|
||||
) -> dict[str, Any]:
|
||||
"""Process the model's response and return the validated dict.
|
||||
|
||||
Args:
|
||||
args: The structured response data matching the output schema.
|
||||
tool_context: Tool execution context.
|
||||
|
||||
Returns:
|
||||
The validated response as dict.
|
||||
"""
|
||||
# Validate the input matches the expected schema
|
||||
validated_response = self.output_schema.model_validate(args)
|
||||
|
||||
# Return the validated dict directly
|
||||
return validated_response.model_dump()
|
||||
@@ -201,19 +201,18 @@ def test_output_schema_with_sub_agents_will_throw():
|
||||
)
|
||||
|
||||
|
||||
def test_output_schema_with_tools_will_throw():
|
||||
def test_output_schema_with_tools_will_not_throw():
|
||||
class Schema(BaseModel):
|
||||
pass
|
||||
|
||||
def _a_tool():
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = LlmAgent(
|
||||
name='test_agent',
|
||||
output_schema=Schema,
|
||||
tools=[_a_tool],
|
||||
)
|
||||
LlmAgent(
|
||||
name='test_agent',
|
||||
output_schema=Schema,
|
||||
tools=[_a_tool],
|
||||
)
|
||||
|
||||
|
||||
def test_before_model_callback():
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for basic LLM request processor."""
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
"""Test schema for output."""
|
||||
|
||||
name: str = Field(description='A name')
|
||||
value: int = Field(description='A value')
|
||||
|
||||
|
||||
def dummy_tool(query: str) -> str:
|
||||
"""A dummy tool for testing."""
|
||||
return f'Result: {query}'
|
||||
|
||||
|
||||
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
|
||||
"""Helper to create InvocationContext for testing."""
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id='test-id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
run_config=RunConfig(),
|
||||
)
|
||||
|
||||
|
||||
class TestBasicLlmRequestProcessor:
|
||||
"""Test class for _BasicLlmRequestProcessor."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_output_schema_when_no_tools(self):
|
||||
"""Test that processor sets output_schema when agent has no tools."""
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=OutputSchema,
|
||||
tools=[], # No tools
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should have set response_schema since agent has no tools
|
||||
assert llm_request.config.response_schema == OutputSchema
|
||||
assert llm_request.config.response_mime_type == 'application/json'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_output_schema_when_tools_present(self):
|
||||
"""Test that processor skips output_schema when agent has tools."""
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=OutputSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)], # Has tools
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should NOT have set response_schema since agent has tools
|
||||
assert llm_request.config.response_schema is None
|
||||
assert llm_request.config.response_mime_type != 'application/json'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_output_schema_no_tools(self):
|
||||
"""Test that processor works normally when agent has no output_schema or tools."""
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
# No output_schema, no tools
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should not have set anything
|
||||
assert llm_request.config.response_schema is None
|
||||
assert llm_request.config.response_mime_type != 'application/json'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sets_model_name(self):
|
||||
"""Test that processor sets the model name correctly."""
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should have set the model name
|
||||
assert llm_request.model == 'gemini-1.5-flash'
|
||||
@@ -0,0 +1,409 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for output schema processor functionality."""
|
||||
|
||||
import json
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.flows.llm_flows.single_flow import SingleFlow
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
|
||||
class PersonSchema(BaseModel):
|
||||
"""Test schema for structured output."""
|
||||
|
||||
name: str = Field(description="A person's name")
|
||||
age: int = Field(description="A person's age")
|
||||
city: str = Field(description='The city they live in')
|
||||
|
||||
|
||||
def dummy_tool(query: str) -> str:
|
||||
"""A dummy tool for testing."""
|
||||
return f'Searched for: {query}'
|
||||
|
||||
|
||||
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
|
||||
"""Helper to create InvocationContext for testing."""
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id='test-id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
run_config=RunConfig(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_schema_with_tools_validation_removed():
|
||||
"""Test that LlmAgent now allows output_schema with tools."""
|
||||
# This should not raise an error anymore
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
assert agent.output_schema == PersonSchema
|
||||
assert len(agent.tools) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_processor_skips_output_schema_with_tools():
|
||||
"""Test that basic processor doesn't set output_schema when tools are present."""
|
||||
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should not have set response_schema since agent has tools
|
||||
assert llm_request.config.response_schema is None
|
||||
assert llm_request.config.response_mime_type != 'application/json'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_processor_sets_output_schema_without_tools():
|
||||
"""Test that basic processor still sets output_schema when no tools are present."""
|
||||
from google.adk.flows.llm_flows.basic import _BasicLlmRequestProcessor
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[], # No tools
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
|
||||
llm_request = LlmRequest()
|
||||
processor = _BasicLlmRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should have set response_schema since agent has no tools
|
||||
assert llm_request.config.response_schema == PersonSchema
|
||||
assert llm_request.config.response_mime_type == 'application/json'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_schema_request_processor():
|
||||
"""Test that output schema processor adds set_model_response tool."""
|
||||
from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
|
||||
llm_request = LlmRequest()
|
||||
processor = _OutputSchemaRequestProcessor()
|
||||
|
||||
# Process the request
|
||||
events = []
|
||||
async for event in processor.run_async(invocation_context, llm_request):
|
||||
events.append(event)
|
||||
|
||||
# Should have added set_model_response tool
|
||||
assert 'set_model_response' in llm_request.tools_dict
|
||||
|
||||
# Should have added instruction about using set_model_response
|
||||
assert 'set_model_response' in llm_request.config.system_instruction
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_response_tool():
|
||||
"""Test the set_model_response tool functionality."""
|
||||
from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY
|
||||
from google.adk.tools.set_model_response_tool import SetModelResponseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# Call the tool with valid data
|
||||
result = await tool.run_async(
|
||||
args={'name': 'John Doe', 'age': 30, 'city': 'New York'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Verify the tool now returns dict directly
|
||||
assert result is not None
|
||||
assert result['name'] == 'John Doe'
|
||||
assert result['age'] == 30
|
||||
assert result['city'] == 'New York'
|
||||
|
||||
# Check that the response is no longer stored in session state
|
||||
stored_response = invocation_context.session.state.get(
|
||||
MODEL_JSON_RESPONSE_KEY
|
||||
)
|
||||
assert stored_response is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_schema_helper_functions():
|
||||
"""Test the helper functions for handling set_model_response."""
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows._output_schema_processor import create_final_model_response_event
|
||||
from google.adk.flows.llm_flows._output_schema_processor import get_structured_model_response
|
||||
from google.genai import types
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
|
||||
# Test get_structured_model_response with a function response event
|
||||
test_dict = {'name': 'Jane Smith', 'age': 25, 'city': 'Los Angeles'}
|
||||
test_json = '{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}'
|
||||
|
||||
# Create a function response event with set_model_response
|
||||
function_response_event = Event(
|
||||
author='test_agent',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name='set_model_response', response=test_dict
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Test get_structured_model_response function
|
||||
extracted_json = get_structured_model_response(function_response_event)
|
||||
assert extracted_json == test_json
|
||||
|
||||
# Test create_final_model_response_event function
|
||||
final_event = create_final_model_response_event(invocation_context, test_json)
|
||||
assert final_event.author == 'test_agent'
|
||||
assert final_event.content.role == 'model'
|
||||
assert final_event.content.parts[0].text == test_json
|
||||
|
||||
# Test get_structured_model_response with non-set_model_response function
|
||||
other_function_response_event = Event(
|
||||
author='test_agent',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name='other_tool', response={'result': 'other response'}
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
extracted_json = get_structured_model_response(other_function_response_event)
|
||||
assert extracted_json is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_integration():
|
||||
"""Test the complete output schema with tools integration."""
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
|
||||
# Create a flow and test the processors
|
||||
flow = SingleFlow()
|
||||
llm_request = LlmRequest()
|
||||
|
||||
# Run all request processors
|
||||
async for event in flow._preprocess_async(invocation_context, llm_request):
|
||||
pass
|
||||
|
||||
# Verify set_model_response tool was added
|
||||
assert 'set_model_response' in llm_request.tools_dict
|
||||
|
||||
# Verify instruction was added
|
||||
assert 'set_model_response' in llm_request.config.system_instruction
|
||||
|
||||
# Verify output_schema was NOT set on the model config
|
||||
assert llm_request.config.response_schema is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_yields_both_events_for_set_model_response():
|
||||
"""Test that the flow yields both function response and final model response events."""
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
|
||||
from google.adk.tools.set_model_response_tool import SetModelResponseTool
|
||||
from google.genai import types
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
output_schema=PersonSchema,
|
||||
tools=[],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
flow = BaseLlmFlow()
|
||||
|
||||
# Create a set_model_response tool and add it to the tools dict
|
||||
set_response_tool = SetModelResponseTool(PersonSchema)
|
||||
llm_request = LlmRequest()
|
||||
llm_request.tools_dict['set_model_response'] = set_response_tool
|
||||
|
||||
# Create a function call event (model calling the function)
|
||||
function_call_event = Event(
|
||||
author='test_agent',
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
name='set_model_response',
|
||||
args={
|
||||
'name': 'Test User',
|
||||
'age': 30,
|
||||
'city': 'Test City',
|
||||
},
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Test the postprocess function handling
|
||||
events = []
|
||||
async for event in flow._postprocess_handle_function_calls_async(
|
||||
invocation_context, function_call_event, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should yield exactly 2 events: function response + final model response
|
||||
assert len(events) == 2
|
||||
|
||||
# First event should be the function response
|
||||
first_event = events[0]
|
||||
assert first_event.get_function_responses()[0].name == 'set_model_response'
|
||||
# The response should be the dict returned by the tool
|
||||
assert first_event.get_function_responses()[0].response == {
|
||||
'name': 'Test User',
|
||||
'age': 30,
|
||||
'city': 'Test City',
|
||||
}
|
||||
|
||||
# Second event should be the final model response with JSON
|
||||
second_event = events[1]
|
||||
assert second_event.author == 'test_agent'
|
||||
assert second_event.content.role == 'model'
|
||||
assert (
|
||||
second_event.content.parts[0].text
|
||||
== '{"name": "Test User", "age": 30, "city": "Test City"}'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_yields_only_function_response_for_normal_tools():
|
||||
"""Test that the flow yields only function response event for non-set_model_response tools."""
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
|
||||
from google.genai import types
|
||||
|
||||
agent = LlmAgent(
|
||||
name='test_agent',
|
||||
model='gemini-1.5-flash',
|
||||
tools=[FunctionTool(func=dummy_tool)],
|
||||
)
|
||||
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
flow = BaseLlmFlow()
|
||||
|
||||
# Create a dummy tool and add it to the tools dict
|
||||
dummy_function_tool = FunctionTool(func=dummy_tool)
|
||||
llm_request = LlmRequest()
|
||||
llm_request.tools_dict['dummy_tool'] = dummy_function_tool
|
||||
|
||||
# Create a function call event (model calling the dummy tool)
|
||||
function_call_event = Event(
|
||||
author='test_agent',
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[
|
||||
types.Part(
|
||||
function_call=types.FunctionCall(
|
||||
name='dummy_tool', args={'query': 'test query'}
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Test the postprocess function handling
|
||||
events = []
|
||||
async for event in flow._postprocess_handle_function_calls_async(
|
||||
invocation_context, function_call_event, llm_request
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should yield exactly 1 event: just the function response
|
||||
assert len(events) == 1
|
||||
|
||||
# Should be the function response from dummy_tool
|
||||
first_event = events[0]
|
||||
assert first_event.get_function_responses()[0].name == 'dummy_tool'
|
||||
assert first_event.get_function_responses()[0].response == {
|
||||
'result': 'Searched for: test query'
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for SetModelResponseTool."""
|
||||
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.set_model_response_tool import MODEL_JSON_RESPONSE_KEY
|
||||
from google.adk.tools.set_model_response_tool import SetModelResponseTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import ValidationError
|
||||
import pytest
|
||||
|
||||
|
||||
class PersonSchema(BaseModel):
|
||||
"""Test schema for structured output."""
|
||||
|
||||
name: str = Field(description="A person's name")
|
||||
age: int = Field(description="A person's age")
|
||||
city: str = Field(description='The city they live in')
|
||||
|
||||
|
||||
class ComplexSchema(BaseModel):
|
||||
"""More complex test schema."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, str] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
async def _create_invocation_context(agent: LlmAgent) -> InvocationContext:
|
||||
"""Helper to create InvocationContext for testing."""
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
app_name='test_app', user_id='test_user'
|
||||
)
|
||||
return InvocationContext(
|
||||
invocation_id='test-id',
|
||||
agent=agent,
|
||||
session=session,
|
||||
session_service=session_service,
|
||||
run_config=RunConfig(),
|
||||
)
|
||||
|
||||
|
||||
def test_tool_initialization_simple_schema():
|
||||
"""Test tool initialization with a simple schema."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
assert tool.output_schema == PersonSchema
|
||||
assert tool.name == 'set_model_response'
|
||||
assert 'Set your final response' in tool.description
|
||||
assert tool.func is not None
|
||||
|
||||
|
||||
def test_tool_initialization_complex_schema():
|
||||
"""Test tool initialization with a complex schema."""
|
||||
tool = SetModelResponseTool(ComplexSchema)
|
||||
|
||||
assert tool.output_schema == ComplexSchema
|
||||
assert tool.name == 'set_model_response'
|
||||
assert tool.func is not None
|
||||
|
||||
|
||||
def test_function_signature_generation():
|
||||
"""Test that function signature is correctly generated from schema."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(tool.func)
|
||||
|
||||
# Check that parameters match schema fields
|
||||
assert 'name' in sig.parameters
|
||||
assert 'age' in sig.parameters
|
||||
assert 'city' in sig.parameters
|
||||
|
||||
# All parameters should be keyword-only
|
||||
for param in sig.parameters.values():
|
||||
assert param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
|
||||
|
||||
def test_get_declaration():
|
||||
"""Test that tool declaration is properly generated."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
assert declaration is not None
|
||||
assert declaration.name == 'set_model_response'
|
||||
assert declaration.description is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_valid_data():
|
||||
"""Test tool execution with valid data."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# Execute with valid data
|
||||
result = await tool.run_async(
|
||||
args={'name': 'Alice', 'age': 25, 'city': 'Seattle'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Verify the tool now returns dict directly
|
||||
assert result is not None
|
||||
assert result['name'] == 'Alice'
|
||||
assert result['age'] == 25
|
||||
assert result['city'] == 'Seattle'
|
||||
|
||||
# Verify data is no longer stored in session state (old behavior)
|
||||
stored_response = invocation_context.session.state.get(
|
||||
MODEL_JSON_RESPONSE_KEY
|
||||
)
|
||||
assert stored_response is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_complex_schema():
|
||||
"""Test tool execution with complex schema."""
|
||||
tool = SetModelResponseTool(ComplexSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# Execute with complex data
|
||||
result = await tool.run_async(
|
||||
args={
|
||||
'id': 123,
|
||||
'title': 'Test Item',
|
||||
'tags': ['tag1', 'tag2'],
|
||||
'metadata': {'key': 'value'},
|
||||
'is_active': False,
|
||||
},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Verify the tool now returns dict directly
|
||||
assert result is not None
|
||||
assert result['id'] == 123
|
||||
assert result['title'] == 'Test Item'
|
||||
assert result['tags'] == ['tag1', 'tag2']
|
||||
assert result['metadata'] == {'key': 'value'}
|
||||
assert result['is_active'] is False
|
||||
|
||||
# Verify data is no longer stored in session state (old behavior)
|
||||
stored_response = invocation_context.session.state.get(
|
||||
MODEL_JSON_RESPONSE_KEY
|
||||
)
|
||||
assert stored_response is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_validation_error():
|
||||
"""Test tool execution with invalid data raises validation error."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# Execute with invalid data (wrong type for age)
|
||||
with pytest.raises(ValidationError):
|
||||
await tool.run_async(
|
||||
args={'name': 'Bob', 'age': 'not_a_number', 'city': 'Portland'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_missing_required_field():
|
||||
"""Test tool execution with missing required field."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# Execute with missing required field
|
||||
with pytest.raises(ValidationError):
|
||||
await tool.run_async(
|
||||
args={'name': 'Charlie', 'city': 'Denver'}, # Missing age
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_state_storage_key():
|
||||
"""Test that response is no longer stored in session state."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
result = await tool.run_async(
|
||||
args={'name': 'Diana', 'age': 35, 'city': 'Miami'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Verify response is returned directly, not stored in session state
|
||||
assert result is not None
|
||||
assert result['name'] == 'Diana'
|
||||
assert result['age'] == 35
|
||||
assert result['city'] == 'Miami'
|
||||
|
||||
# Verify session state is no longer used
|
||||
assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_executions_return_latest():
|
||||
"""Test that multiple executions return latest response independently."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
agent = LlmAgent(name='test_agent', model='gemini-1.5-flash')
|
||||
invocation_context = await _create_invocation_context(agent)
|
||||
tool_context = ToolContext(invocation_context)
|
||||
|
||||
# First execution
|
||||
result1 = await tool.run_async(
|
||||
args={'name': 'First', 'age': 20, 'city': 'City1'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Second execution should return its own response
|
||||
result2 = await tool.run_async(
|
||||
args={'name': 'Second', 'age': 30, 'city': 'City2'},
|
||||
tool_context=tool_context,
|
||||
)
|
||||
|
||||
# Verify each execution returns its own dict
|
||||
assert result1['name'] == 'First'
|
||||
assert result1['age'] == 20
|
||||
assert result1['city'] == 'City1'
|
||||
|
||||
assert result2['name'] == 'Second'
|
||||
assert result2['age'] == 30
|
||||
assert result2['city'] == 'City2'
|
||||
|
||||
# Verify session state is not used
|
||||
assert MODEL_JSON_RESPONSE_KEY not in invocation_context.session.state
|
||||
|
||||
|
||||
def test_function_return_value_consistency():
|
||||
"""Test that function return value matches run_async return value."""
|
||||
tool = SetModelResponseTool(PersonSchema)
|
||||
|
||||
# Direct function call
|
||||
direct_result = tool.func()
|
||||
|
||||
# Both should return the same value
|
||||
assert direct_result == 'Response set successfully.'
|
||||
Reference in New Issue
Block a user