test: Add unittest suites for testing HITL confirmation flow on runner level

PiperOrigin-RevId: 807327997
This commit is contained in:
Xinran (Sherry) Tang
2025-09-15 12:00:16 -07:00
committed by Copybara-Service
parent 402f3626b3
commit 4dbec15d26
2 changed files with 378 additions and 0 deletions
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,365 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for HITL flows with different agent structures."""
import copy
from unittest import mock
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.parallel_agent import ParallelAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.genai.types import FunctionCall
from google.genai.types import FunctionResponse
from google.genai.types import GenerateContentResponse
from google.genai.types import Part
import pytest
from .. import testing_utils
def _create_llm_response_from_tools(
tools: list[FunctionTool],
) -> GenerateContentResponse:
"""Creates a mock LLM response containing a function call."""
parts = [
Part(function_call=FunctionCall(name=tool.name, args={}))
for tool in tools
]
return testing_utils.LlmResponse(
content=testing_utils.ModelContent(parts=parts)
)
def _create_llm_response_from_text(text: str) -> GenerateContentResponse:
"""Creates a mock LLM response containing text."""
return testing_utils.LlmResponse(
content=testing_utils.ModelContent(parts=[Part(text=text)])
)
def _test_request_confirmation_function(
tool_context: ToolContext,
) -> dict[str, str]:
"""A test tool function that requests confirmation."""
if not tool_context.tool_confirmation:
tool_context.request_confirmation(hint="test hint for request_confirmation")
return {"error": "test error for request_confirmation"}
return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"}
def _test_request_confirmation_function_with_custom_schema(
tool_context: ToolContext,
) -> dict[str, str]:
"""A test tool function that requests confirmation, but with a custom payload schema."""
if not tool_context.tool_confirmation:
tool_context.request_confirmation(
hint="test hint for request_confirmation with custom payload schema",
payload={
"test_custom_payload": {
"int_field": 0,
"str_field": "",
"bool_field": False,
}
},
)
return {"error": "test error for request_confirmation"}
return {
"result": f"confirmed={tool_context.tool_confirmation.confirmed}",
"custom_payload": tool_context.tool_confirmation.payload,
}
class BaseHITLTest:
"""Base class for HITL tests with common fixtures."""
@pytest.fixture
def runner(self, agent: BaseAgent) -> testing_utils.InMemoryRunner:
"""Provides an in-memory runner for the agent."""
return testing_utils.InMemoryRunner(root_agent=agent)
class TestHITLConfirmationFlowWithSingleAgent(BaseHITLTest):
"""Tests the HITL confirmation flow with a single LlmAgent."""
@pytest.fixture
def tools(self) -> list[FunctionTool]:
"""Provides the tools for the agent."""
return [FunctionTool(func=_test_request_confirmation_function)]
@pytest.fixture
def llm_responses(
self, tools: list[FunctionTool]
) -> list[GenerateContentResponse]:
"""Provides mock LLM responses for the tests."""
return [
_create_llm_response_from_tools(tools),
_create_llm_response_from_text("test llm response after tool call"),
_create_llm_response_from_text(
"test llm response after final tool call"
),
]
@pytest.fixture
def mock_model(
self, llm_responses: list[GenerateContentResponse]
) -> testing_utils.MockModel:
"""Provides a mock model with predefined responses."""
return testing_utils.MockModel(responses=llm_responses)
@pytest.fixture
def agent(
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
) -> LlmAgent:
"""Provides a single LlmAgent for the test."""
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
@pytest.mark.asyncio
@pytest.mark.parametrize("tool_call_confirmed", [True, False])
async def test_confirmation_flow(
self,
runner: testing_utils.InMemoryRunner,
agent: LlmAgent,
tool_call_confirmed: bool,
):
"""Tests HITL flow where all tool calls are confirmed."""
user_query = testing_utils.UserContent("test user query")
events = await runner.run_async(user_query)
tools = agent.tools
expected_parts = [
(
agent.name,
Part(function_call=FunctionCall(name=tools[0].name, args={})),
),
(
agent.name,
Part(
function_call=FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
"originalFunctionCall": {
"name": tools[0].name,
"id": mock.ANY,
"args": {},
},
"toolConfirmation": {
"hint": "test hint for request_confirmation",
"confirmed": False,
},
},
)
),
),
(
agent.name,
Part(
function_response=FunctionResponse(
name=tools[0].name,
response={"error": "test error for request_confirmation"},
)
),
),
(agent.name, "test llm response after tool call"),
]
simplified = testing_utils.simplify_events(copy.deepcopy(events))
for i, (agent_name, part) in enumerate(expected_parts):
assert simplified[i][0] == agent_name
assert simplified[i][1] == part
ask_for_confirmation_function_call_id = (
events[1].content.parts[0].function_call.id
)
user_confirmation = testing_utils.UserContent(
Part(
function_response=FunctionResponse(
id=ask_for_confirmation_function_call_id,
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
response={"confirmed": tool_call_confirmed},
)
)
)
events = await runner.run_async(user_confirmation)
expected_parts_final = [
(
agent.name,
Part(
function_response=FunctionResponse(
name=tools[0].name,
response={"result": f"confirmed={tool_call_confirmed}"},
)
),
),
(agent.name, "test llm response after final tool call"),
]
assert (
testing_utils.simplify_events(copy.deepcopy(events))
== expected_parts_final
)
class TestHITLConfirmationFlowWithCustomPayloadSchema(BaseHITLTest):
"""Tests the HITL confirmation flow with a single agent, for custom confirmation payload schema."""
@pytest.fixture
def tools(self) -> list[FunctionTool]:
"""Provides the tools for the agent."""
return [
FunctionTool(
func=_test_request_confirmation_function_with_custom_schema
)
]
@pytest.fixture
def llm_responses(
self, tools: list[FunctionTool]
) -> list[GenerateContentResponse]:
"""Provides mock LLM responses for the tests."""
return [
_create_llm_response_from_tools(tools),
_create_llm_response_from_text("test llm response after tool call"),
_create_llm_response_from_text(
"test llm response after final tool call"
),
]
@pytest.fixture
def mock_model(
self, llm_responses: list[GenerateContentResponse]
) -> testing_utils.MockModel:
"""Provides a mock model with predefined responses."""
return testing_utils.MockModel(responses=llm_responses)
@pytest.fixture
def agent(
self, mock_model: testing_utils.MockModel, tools: list[FunctionTool]
) -> LlmAgent:
"""Provides a single LlmAgent for the test."""
return LlmAgent(name="root_agent", model=mock_model, tools=tools)
@pytest.mark.asyncio
@pytest.mark.parametrize("tool_call_confirmed", [True, False])
async def test_confirmation_flow(
self,
runner: testing_utils.InMemoryRunner,
agent: LlmAgent,
tool_call_confirmed: bool,
):
"""Tests HITL flow with custom payload schema."""
tools = agent.tools
user_query = testing_utils.UserContent("test user query")
events = await runner.run_async(user_query)
expected_parts = [
(
agent.name,
Part(function_call=FunctionCall(name=tools[0].name, args={})),
),
(
agent.name,
Part(
function_call=FunctionCall(
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
args={
"originalFunctionCall": {
"name": tools[0].name,
"id": mock.ANY,
"args": {},
},
"toolConfirmation": {
"hint": (
"test hint for request_confirmation with"
" custom payload schema"
),
"confirmed": False,
"payload": {
"test_custom_payload": {
"int_field": 0,
"str_field": "",
"bool_field": False,
}
},
},
},
)
),
),
(
agent.name,
Part(
function_response=FunctionResponse(
name=tools[0].name,
response={"error": "test error for request_confirmation"},
)
),
),
(agent.name, "test llm response after tool call"),
]
simplified = testing_utils.simplify_events(copy.deepcopy(events))
for i, (agent_name, part) in enumerate(expected_parts):
assert simplified[i][0] == agent_name
assert simplified[i][1] == part
ask_for_confirmation_function_call_id = (
events[1].content.parts[0].function_call.id
)
custom_payload = {
"test_custom_payload": {
"int_field": 123,
"str_field": "test_str",
"bool_field": True,
}
}
user_confirmation = testing_utils.UserContent(
Part(
function_response=FunctionResponse(
id=ask_for_confirmation_function_call_id,
name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME,
response={
"confirmed": tool_call_confirmed,
"payload": custom_payload,
},
)
)
)
events = await runner.run_async(user_confirmation)
expected_response = {
"result": f"confirmed={tool_call_confirmed}",
"custom_payload": custom_payload,
}
expected_parts_final = [
(
agent.name,
Part(
function_response=FunctionResponse(
name=tools[0].name,
response=expected_response,
)
),
),
(agent.name, "test llm response after final tool call"),
]
assert (
testing_utils.simplify_events(copy.deepcopy(events))
== expected_parts_final
)