feat: Adds Static User Simulator and User Simulator Provider

Details:
- Adds the `StaticUserSimulator`, which implements the existing behavior of supplying a fixed list of user prompts for an EvalCase.
- Adds the `UserSimulatorProvider`, which selects the user simulator an EvalCase requires (`StaticUserSimulator` or `LlmBackedUserSimulator`); a dispatch sketch follows the commit metadata below.
- Integrates the UserSimulatorProvider and UserSimulator into the CLI and evaluation infrastructure.
- Updates and adds unit tests for the new functionality.
- Miscellaneous updates to lay groundwork for a full implementation of the LlmBackedUserSimulator in the future.
PiperOrigin-RevId: 822198401
Authored by Google Team Member on 2025-10-21 11:14:45 -07:00; committed by Copybara-Service.
parent 4a842c5a13
commit aeaec859bf
17 changed files with 727 additions and 62 deletions
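The summary above describes a per-EvalCase dispatch. Below is a minimal sketch of that behavior, assuming a google-adk build that contains this commit; note the LLM-backed branch may resolve its model via LLMRegistry at construction time (the unit tests mock this out), so treat that half as illustrative.

```python
from google.adk.evaluation.conversation_scenarios import ConversationScenario
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.user_simulator_provider import UserSimulatorProvider
from google.genai import types

provider = UserSimulatorProvider()  # no config: falls back to BaseUserSimulatorConfig

# An EvalCase with a static conversation dispatches to StaticUserSimulator.
static_case = EvalCase(
    eval_id="static_case",
    conversation=[
        Invocation(
            invocation_id="inv1",
            user_content=types.Content(parts=[types.Part(text="Hello!")]),
        )
    ],
)
print(type(provider.provide(static_case)).__name__)  # StaticUserSimulator

# An EvalCase with a conversation scenario dispatches to LlmBackedUserSimulator
# (whose get_next_user_message is still NotImplemented in this commit).
scenario_case = EvalCase(
    eval_id="scenario_case",
    conversation_scenario=ConversationScenario(
        starting_prompt="Hello!", conversation_plan="Ask about the weather."
    ),
)
print(type(provider.provide(scenario_case)).__name__)  # LlmBackedUserSimulator
```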
@@ -549,6 +549,7 @@ def cli_eval(
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..evaluation.user_simulator_provider import UserSimulatorProvider
from .cli_eval import _collect_eval_results
from .cli_eval import _collect_inferences
from .cli_eval import get_root_agent
@@ -638,11 +639,16 @@ def cli_eval(
)
)
user_simulator_provider = UserSimulatorProvider(
user_simulator_config=eval_config.user_simulator_config
)
try:
eval_service = LocalEvalService(
root_agent=root_agent,
eval_sets_manager=eval_sets_manager,
eval_set_results_manager=eval_set_results_manager,
user_simulator_provider=user_simulator_provider,
)
inference_results = asyncio.run(
@@ -50,6 +50,7 @@ from .eval_sets_manager import EvalSetsManager
from .evaluator import EvalStatus
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema
from .user_simulator_provider import UserSimulatorProvider
logger = logging.getLogger("google_adk." + __name__)
@@ -149,12 +150,17 @@ class AgentEvaluator:
)
eval_metrics = get_eval_metrics_from_config(eval_config)
user_simulator_provider = UserSimulatorProvider(
user_simulator_config=eval_config.user_simulator_config
)
# Step 1: Perform evals, basically inferencing and evaluation of metrics
eval_results_by_eval_id = await AgentEvaluator._get_eval_results_by_eval_id(
agent_for_eval=agent_for_eval,
eval_set=eval_set,
eval_metrics=eval_metrics,
num_runs=num_runs,
user_simulator_provider=user_simulator_provider,
)
# Step 2: Post-process the results!
@@ -518,6 +524,7 @@ class AgentEvaluator:
eval_set: EvalSet,
eval_metrics: list[EvalMetric],
num_runs: int,
user_simulator_provider: UserSimulatorProvider,
) -> dict[str, list[EvalCaseResult]]:
"""Returns EvalCaseResults grouped by eval case id.
@@ -541,6 +548,7 @@ class AgentEvaluator:
eval_sets_manager=AgentEvaluator._get_eval_sets_manager(
app_name=app_name, eval_set=eval_set
),
user_simulator_provider=user_simulator_provider,
)
inference_requests = [
@@ -14,22 +14,19 @@
from __future__ import annotations
from typing import Union
from google.genai import types as genai_types
from pydantic import Field
from .common import EvalBaseModel
class ConversationScenario(EvalBaseModel):
"""Scenario for a conversation between a simulated user and the Agent."""
"""Scenario for a conversation between a simulated user and the Agent under test."""
starting_prompt: Union[str, genai_types.Content]
starting_prompt: str
"""Starting prompt for the conversation.
This prompt acts as the first user message that is given to the Agent. Any
subsequent user messages are obtained by the system that is simulating the
This prompt acts as the fixed first user message that is given to the Agent.
Any subsequent user messages are obtained by the system that is simulating the
user.
"""
@@ -20,6 +20,7 @@ from typing import Union
from google.genai import types as genai_types
from pydantic import Field
from pydantic import model_validator
from typing_extensions import TypeAlias
from .app_details import AppDetails
@@ -121,7 +122,7 @@ class SessionInput(EvalBaseModel):
StaticConversation: TypeAlias = list[Invocation]
"""A conversation where user's query for each invocation is already specified."""
"""A conversation where the user's queries for each invocation are already specified."""
class EvalCase(EvalBaseModel):
@@ -158,6 +159,15 @@ class EvalCase(EvalBaseModel):
)
"""A list of rubrics that are applicable to all the invocations in the conversation of this eval case."""
@model_validator(mode="after")
def ensure_conversation_xor_conversation_scenario(self) -> EvalCase:
if (self.conversation is None) == (self.conversation_scenario is None):
raise ValueError(
"Exactly one of conversation and conversation_scenario must be"
" provided in an EvalCase."
)
return self
def get_all_tool_calls(
intermediate_data: Optional[IntermediateDataType],
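The new validator above enforces that an EvalCase carries exactly one source of user turns. A quick sketch, assuming the models from this change are importable (pydantic surfaces the validator's ValueError as a ValidationError, which subclasses ValueError):

```python
from google.adk.evaluation.eval_case import EvalCase

EvalCase(eval_id="ok", conversation=[])  # valid: a (possibly empty) static conversation

try:
  EvalCase(eval_id="bad")  # neither conversation nor conversation_scenario
except ValueError as err:
  print(err)  # "Exactly one of conversation and conversation_scenario ..."
```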
@@ -73,7 +73,7 @@ the third one uses `LlmAsAJudgeCriterion`.
user_simulator_config: Optional[BaseUserSimulatorConfig] = Field(
default=None,
description="""Config to be used by the user simulator.""",
description="Config to be used by the user simulator.",
)
@@ -14,11 +14,14 @@
from __future__ import annotations
import copy
import importlib
from typing import Any
from typing import AsyncGenerator
from typing import Optional
import uuid
from google.genai.types import Content
from pydantic import BaseModel
from ..agents.llm_agent import Agent
@@ -41,6 +44,9 @@ from .eval_case import InvocationEvents
from .eval_case import SessionInput
from .eval_set import EvalSet
from .request_intercepter_plugin import _RequestIntercepterPlugin
from .user_simulator import Status as UserSimulatorStatus
from .user_simulator import UserSimulator
from .user_simulator_provider import UserSimulatorProvider
_USER_AUTHOR = "user"
_DEFAULT_AUTHOR = "agent"
@@ -79,11 +85,13 @@ class EvaluationGenerator:
results = []
for eval_case in eval_set.eval_cases:
# assume only static conversations are needed
user_simulator = UserSimulatorProvider().provide(eval_case)
responses = []
for _ in range(repeat_num):
response_invocations = await EvaluationGenerator._process_query(
eval_case.conversation,
agent_module_path,
user_simulator,
agent_name,
eval_case.session_input,
)
@@ -123,8 +131,8 @@ class EvaluationGenerator:
@staticmethod
async def _process_query(
invocations: list[Invocation],
module_name: str,
user_simulator: UserSimulator,
agent_name: Optional[str] = None,
initial_session: Optional[SessionInput] = None,
) -> list[Invocation]:
@@ -141,13 +149,44 @@ class EvaluationGenerator:
assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
return await EvaluationGenerator._generate_inferences_from_root_agent(
invocations, agent_to_evaluate, reset_func, initial_session
agent_to_evaluate,
user_simulator=user_simulator,
reset_func=reset_func,
initial_session=initial_session,
)
@staticmethod
async def _generate_inferences_for_single_user_invocation(
runner: Runner,
user_id: str,
session_id: str,
user_content: Content,
) -> AsyncGenerator[Event, None]:
invocation_id = None
async with Aclosing(
runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=user_content,
)
) as agen:
async for event in agen:
if not invocation_id:
invocation_id = event.invocation_id
yield Event(
content=user_content,
author=_USER_AUTHOR,
invocation_id=invocation_id,
)
yield event
@staticmethod
async def _generate_inferences_from_root_agent(
invocations: list[Invocation],
root_agent: Agent,
user_simulator: UserSimulator,
reset_func: Optional[Any] = None,
initial_session: Optional[SessionInput] = None,
session_id: Optional[str] = None,
@@ -155,7 +194,8 @@ class EvaluationGenerator:
artifact_service: Optional[BaseArtifactService] = None,
memory_service: Optional[BaseMemoryService] = None,
) -> list[Invocation]:
"""Scrapes the root agent given the list of Invocations."""
"""Scrapes the root agent in coordination with the user simulator."""
if not session_service:
session_service = InMemorySessionService()
@@ -194,29 +234,19 @@ class EvaluationGenerator:
plugins=[request_intercepter_plugin],
) as runner:
events = []
for invocation in invocations:
user_content = invocation.user_content
invocation_id = None
async with Aclosing(
runner.run_async(
user_id=user_id, session_id=session_id, new_message=user_content
)
) as agen:
async for event in agen:
if not invocation_id:
invocation_id = event.invocation_id
events.append(
Event(
content=user_content,
author=_USER_AUTHOR,
invocation_id=invocation_id,
)
)
while True:
next_user_message = await user_simulator.get_next_user_message(
copy.deepcopy(events)
)
if next_user_message.status == UserSimulatorStatus.SUCCESS:
async for (
event
) in EvaluationGenerator._generate_inferences_for_single_user_invocation(
runner, user_id, session_id, next_user_message.user_message
):
events.append(event)
else: # no message generated
break
app_details_by_invocation_id = (
EvaluationGenerator._get_app_details_by_invocation_id(
@@ -22,7 +22,9 @@ from pydantic import Field
from typing_extensions import override
from ..events.event import Event
from ..models.registry import LLMRegistry
from ..utils.feature_decorator import experimental
from .conversation_scenarios import ConversationScenario
from .evaluator import Evaluator
from .user_simulator import BaseUserSimulatorConfig
from .user_simulator import NextUserMessage
@@ -37,52 +39,59 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
description="The model to use for user simulation.",
)
model_config: Optional[genai_types.GenerateContentConfig] = Field(
default=genai_types.GenerateContentConfig,
model_configuration: genai_types.GenerateContentConfig = Field(
default_factory=genai_types.GenerateContentConfig,
description="The configuration for the model.",
)
max_allowed_invocations: int = Field(
default=20,
description="""Maximum number of invocations allowed by the simulated
interaction. This property allows us to stop a run-off conversation, where the
agent and the user simulator get into an never ending loop.
interaction. This property allows us to stop a run-off conversation, where the
agent and the user simulator get into a never ending loop. The initial fixed
prompt is also counted as an invocation.
(Not recommended)If you don't want a limit, you can set the value to -1.
""",
(Not recommended) If you don't want a limit, you can set the value to -1.""",
)
@experimental
class LlmBackedUserSimulator(UserSimulator):
"""A UserSimulator that uses a LLM to generate messages on behalf of the user."""
"""A UserSimulator that uses an LLM to generate messages on behalf of the user."""
config_type: ClassVar[type[LlmBackedUserSimulatorConfig]] = (
LlmBackedUserSimulatorConfig
)
def __init__(self, *, config: BaseUserSimulatorConfig):
def __init__(
self,
*,
config: BaseUserSimulatorConfig,
conversation_scenario: ConversationScenario,
):
super().__init__(config, config_type=LlmBackedUserSimulator.config_type)
self._conversation_scenario = conversation_scenario
@override
async def get_next_user_message(
self,
conversation_plan: str,
events: list[Event],
) -> NextUserMessage:
"""Returns the next user message to send to the agent with help from a LLM.
Args:
conversation_plan: A plan that user simulation system needs to follow as
it plays out the conversation.
events: The unaltered conversation history between the user and the
agent(s) under evaluation.
Returns:
A NextUserMessage object containing the next user message to send to the
agent, or a status indicating why no message was generated.
"""
raise NotImplementedError()
@override
def get_simulation_evaluator(
self,
) -> Evaluator:
) -> Optional[Evaluator]:
"""Returns an Evaluator that evaluates if the simulation was successful or not."""
raise NotImplementedError()
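For reference, a hedged sketch of configuring the LLM-backed simulator described above; the model name is a placeholder, not a value taken from this change:

```python
from google.adk.evaluation.llm_backed_user_simulator import (
    LlmBackedUserSimulatorConfig,
)

config = LlmBackedUserSimulatorConfig(
    model="test_model",  # placeholder; any model name resolvable by LLMRegistry
    max_allowed_invocations=10,  # counts the fixed starting prompt; -1 disables the cap
)
```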
@@ -22,6 +22,8 @@ from typing import Callable
from typing import Optional
import uuid
from google.genai.types import Content
from google.genai.types import Part
from typing_extensions import override
from ..agents.base_agent import BaseAgent
@@ -51,6 +53,7 @@ from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry
from .user_simulator_provider import UserSimulatorProvider
logger = logging.getLogger('google_adk.' + __name__)
@@ -74,6 +77,7 @@ class LocalEvalService(BaseEvalService):
artifact_service: Optional[BaseArtifactService] = None,
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
session_id_supplier: Callable[[], str] = _get_session_id,
user_simulator_provider: UserSimulatorProvider = UserSimulatorProvider(),
):
self._root_agent = root_agent
self._eval_sets_manager = eval_sets_manager
@@ -87,6 +91,7 @@ class LocalEvalService(BaseEvalService):
self._artifact_service = artifact_service
self._eval_set_results_manager = eval_set_results_manager
self._session_id_supplier = session_id_supplier
self._user_simulator_provider = user_simulator_provider
@override
async def perform_inference(
@@ -211,6 +216,48 @@ class LocalEvalService(BaseEvalService):
# would be the score for the eval case.
overall_eval_metric_results = []
user_id = (
eval_case.session_input.user_id
if eval_case.session_input and eval_case.session_input.user_id
else 'test_user_id'
)
if eval_case.conversation_scenario:
logger.warning(
'Skipping evaluation of variable-length conversation scenario in eval'
' set/case %s/%s.',
inference_result.eval_set_id,
inference_result.eval_case_id,
)
for actual_invocation in inference_result.inferences:
eval_metric_result_per_invocation.append(
EvalMetricResultPerInvocation(
actual_invocation=actual_invocation,
expected_invocation=Invocation(
user_content=actual_invocation.user_content,
final_response=Content(
parts=[Part(text='N/A')], role='model'
),
),
)
)
eval_case_result = EvalCaseResult(
eval_set_file=inference_result.eval_set_id,
eval_set_id=inference_result.eval_set_id,
eval_id=inference_result.eval_case_id,
final_eval_status=EvalStatus.NOT_EVALUATED,
overall_eval_metric_results=overall_eval_metric_results,
eval_metric_result_per_invocation=eval_metric_result_per_invocation,
session_id=inference_result.session_id,
session_details=await self._session_service.get_session(
app_name=inference_result.app_name,
user_id=user_id,
session_id=inference_result.session_id,
),
user_id=user_id,
)
return (inference_result, eval_case_result)
if len(inference_result.inferences) != len(eval_case.conversation):
raise ValueError(
'Inferences should match conversations in eval case. Found'
@@ -281,11 +328,6 @@ class LocalEvalService(BaseEvalService):
final_eval_status = self._generate_final_eval_status(
overall_eval_metric_results
)
user_id = (
eval_case.session_input.user_id
if eval_case.session_input and eval_case.session_input.user_id
else 'test_user_id'
)
eval_case_result = EvalCaseResult(
eval_set_file=inference_result.eval_set_id,
@@ -373,8 +415,8 @@ class LocalEvalService(BaseEvalService):
try:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
user_simulator=self._user_simulator_provider.provide(eval_case),
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
@@ -0,0 +1,79 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import Optional
from typing_extensions import override
from ..events.event import Event
from ..utils.feature_decorator import experimental
from .eval_case import StaticConversation
from .evaluator import Evaluator
from .user_simulator import BaseUserSimulatorConfig
from .user_simulator import NextUserMessage
from .user_simulator import Status
from .user_simulator import UserSimulator
logger = logging.getLogger("google_adk." + __name__)
@experimental
class StaticUserSimulator(UserSimulator):
"""A UserSimulator that returns a static list of user messages."""
def __init__(self, *, static_conversation: StaticConversation):
super().__init__(
BaseUserSimulatorConfig(), config_type=BaseUserSimulatorConfig
)
self.static_conversation = static_conversation
self.invocation_idx = 0
@override
async def get_next_user_message(
self,
events: list[Event],
) -> NextUserMessage:
"""Returns the next user message to send to the agent from a static list.
Args:
events: The unaltered conversation history between the user and the
agent(s) under evaluation.
Returns:
A NextUserMessage object containing the next user message to send to the
agent, or a status indicating why no message was generated.
"""
# check if we have reached the end of the list of invocations
if self.invocation_idx >= len(self.static_conversation):
return NextUserMessage(status=Status.STOP_SIGNAL_DETECTED)
# return the next message in the static list
next_user_content = self.static_conversation[
self.invocation_idx
].user_content
self.invocation_idx += 1
return NextUserMessage(
status=Status.SUCCESS,
user_message=next_user_content,
)
@override
def get_simulation_evaluator(
self,
) -> Optional[Evaluator]:
"""The StaticUserSimulator does not require an evaluator."""
return None
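A minimal sketch of how a consumer such as EvaluationGenerator drains this simulator, with signatures taken from the file above and its unit tests:

```python
import asyncio

from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.static_user_simulator import StaticUserSimulator
from google.adk.evaluation.user_simulator import Status
from google.genai import types


async def drain(simulator: StaticUserSimulator) -> None:
  # Mirrors the loop in _generate_inferences_from_root_agent: keep asking for
  # the next user message until the simulator reports a non-SUCCESS status.
  while True:
    msg = await simulator.get_next_user_message(events=[])
    if msg.status != Status.SUCCESS:
      break
    print(msg.user_message.parts[0].text)


conversation = [
    Invocation(
        invocation_id="inv1",
        user_content=types.Content(parts=[types.Part(text="Hi there")]),
    )
]
asyncio.run(drain(StaticUserSimulator(static_conversation=conversation)))
```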
@@ -23,6 +23,7 @@ from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from pydantic import ValidationError
from ..events.event import Event
@@ -32,7 +33,7 @@ from .evaluator import Evaluator
class BaseUserSimulatorConfig(BaseModel):
"""Base class for configurations pertaining to User Simulator."""
"""Base class for configurations pertaining to user simulator."""
model_config = ConfigDict(
alias_generator=alias_generators.to_camel,
@@ -60,13 +61,25 @@ not able to do so."""
)
user_message: Optional[genai_types.Content] = Field(
description="""The next user message."""
description="""The next user message.""", default=None
)
@model_validator(mode="after")
def ensure_user_message_iff_success(self) -> NextUserMessage:
if (self.status == Status.SUCCESS) == (self.user_message is None):
raise ValueError(
"A user_message should be provided if and only if the status is"
" SUCCESS"
)
return self
@experimental
class UserSimulator(ABC):
"""A user simulator for the purposes of automating interaction with an Agent."""
"""A user simulator for the purposes of automating interaction with an Agent.
Typically, you must create one user simulator instance per eval case.
"""
def __init__(
self,
@@ -82,21 +95,22 @@ class UserSimulator(ABC):
async def get_next_user_message(
self,
conversation_plan: str,
events: list[Event],
) -> NextUserMessage:
"""Returns the next user message to send to the agent.
Args:
conversation_plan: A plan that user simulation system needs to follow as
it plays out the conversation.
events: The unaltered conversation history between the user and the
agent(s) under evaluation.
Returns:
A NextUserMessage object containing the next user message to send to the
agent, or a status indicating why no message was generated.
"""
raise NotImplementedError()
def get_simulation_evaluator(
self,
) -> Evaluator:
"""Returns an instnace of an Evaluator that evaluates if the simulation was successful or not."""
) -> Optional[Evaluator]:
"""Returns an instance of an Evaluator that evaluates if the user simulation was successful or not."""
raise NotImplementedError()
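To make the abstract contract above concrete, here is a toy subclass. It is purely illustrative; `EchoUserSimulator` is not part of this change, and `get_simulation_evaluator` is overridden defensively in case the base declares it abstract:

```python
from google.adk.evaluation.user_simulator import BaseUserSimulatorConfig
from google.adk.evaluation.user_simulator import NextUserMessage
from google.adk.evaluation.user_simulator import Status
from google.adk.evaluation.user_simulator import UserSimulator
from google.adk.events.event import Event
from google.genai import types
from typing_extensions import override


class EchoUserSimulator(UserSimulator):
  """Toy simulator: sends one fixed message, then signals a stop."""

  def __init__(self):
    super().__init__(BaseUserSimulatorConfig(), config_type=BaseUserSimulatorConfig)
    self._sent = False

  @override
  async def get_next_user_message(self, events: list[Event]) -> NextUserMessage:
    if self._sent:
      return NextUserMessage(status=Status.STOP_SIGNAL_DETECTED)
    self._sent = True
    return NextUserMessage(
        status=Status.SUCCESS,
        user_message=types.Content(parts=[types.Part(text="ping")]),
    )

  @override
  def get_simulation_evaluator(self):
    return None
```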
@@ -0,0 +1,77 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from ..utils.feature_decorator import experimental
from .eval_case import EvalCase
from .llm_backed_user_simulator import LlmBackedUserSimulator
from .static_user_simulator import StaticUserSimulator
from .user_simulator import BaseUserSimulatorConfig
from .user_simulator import UserSimulator
@experimental
class UserSimulatorProvider:
"""Provides a UserSimulator instance per EvalCase, mixing configuration data
from the EvalConfig with per-EvalCase conversation data."""
def __init__(
self,
user_simulator_config: Optional[BaseUserSimulatorConfig] = None,
):
if user_simulator_config is None:
user_simulator_config = BaseUserSimulatorConfig()
elif not isinstance(user_simulator_config, BaseUserSimulatorConfig):
# assume that the user simulator will fully validate the config it gets.
raise ValueError(f"Expect config of type `{BaseUserSimulatorConfig}`.")
self._user_simulator_config = user_simulator_config
def provide(self, eval_case: EvalCase) -> UserSimulator:
"""Provide an appropriate user simulator based on the type of conversation data in the EvalCase
Args:
eval_case: An EvalCase containing a `conversation` xor a
`conversation_scenario`.
Returns:
A StaticUserSimulator or an LlmBackedUserSimulator based on the type of
the conversation data.
Raises:
ValueError: If no conversation data or multiple types of conversation data
are provided.
"""
if eval_case.conversation is None:
if eval_case.conversation_scenario is None:
raise ValueError(
"Neither static invocations nor conversation scenario provided in"
" EvalCase. Provide exactly one."
)
return LlmBackedUserSimulator(
config=self._user_simulator_config,
conversation_scenario=eval_case.conversation_scenario,
)
else: # eval_case.conversation is not None
if eval_case.conversation_scenario is not None:
raise ValueError(
"Both static invocations and conversation scenario provided in"
" EvalCase. Provide exactly one."
)
return StaticUserSimulator(static_conversation=eval_case.conversation)
@@ -14,6 +14,8 @@
from __future__ import annotations
from google.adk.evaluation.conversation_scenarios import ConversationScenario
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import get_all_tool_calls
from google.adk.evaluation.eval_case import get_all_tool_calls_with_responses
from google.adk.evaluation.eval_case import get_all_tool_responses
@@ -246,3 +248,36 @@ def test_get_all_tool_calls_with_responses_with_invocation_events():
(tool_call1, tool_response1),
(tool_call2, None),
]
def test_conversation_and_conversation_scenario_mutual_exclusion():
"""Tests the ensure_conversation_xor_conversation_scenario validator."""
test_conversation_scenario = ConversationScenario(
starting_prompt='', conversation_plan=''
)
with pytest.raises(
ValueError,
match=(
'Exactly one of conversation and conversation_scenario must be'
' provided in an EvalCase.'
),
):
EvalCase(eval_id='test_id')
with pytest.raises(
ValueError,
match=(
'Exactly one of conversation and conversation_scenario must be'
' provided in an EvalCase.'
),
):
EvalCase(
eval_id='test_id',
conversation=[],
conversation_scenario=test_conversation_scenario,
)
# these two should not cause exceptions
EvalCase(eval_id='test_id', conversation=[])
EvalCase(eval_id='test_id', conversation_scenario=test_conversation_scenario)
@@ -18,9 +18,13 @@ from google.adk.evaluation.app_details import AgentDetails
from google.adk.evaluation.app_details import AppDetails
from google.adk.evaluation.evaluation_generator import EvaluationGenerator
from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin
from google.adk.evaluation.user_simulator import NextUserMessage
from google.adk.evaluation.user_simulator import Status as UserSimulatorStatus
from google.adk.evaluation.user_simulator import UserSimulator
from google.adk.events.event import Event
from google.adk.models.llm_request import LlmRequest
from google.genai import types
import pytest
def _build_event(
@@ -324,3 +328,130 @@ class TestGetAppDetailsByInvocationId:
}
assert app_details == expected_app_details
assert mock_request_intercepter.get_model_request.call_count == 3
class TestGenerateInferencesForSingleUserInvocation:
"""Test cases for EvaluationGenerator._generate_inferences_for_single_user_invocation method."""
@pytest.mark.asyncio
async def test_generate_inferences_with_mock_runner(self, mocker):
"""Tests inference generation with a mocked runner."""
runner = mocker.MagicMock()
agent_parts = [types.Part(text="Agent response")]
async def mock_run_async(*args, **kwargs):
yield _build_event(
author="agent",
parts=agent_parts,
invocation_id="inv1",
)
runner.run_async.return_value = mock_run_async()
user_content = types.Content(parts=[types.Part(text="User query")])
events = [
event
async for event in EvaluationGenerator._generate_inferences_for_single_user_invocation(
runner, "test_user", "test_session", user_content
)
]
assert len(events) == 2
assert events[0].author == "user"
assert events[0].content == user_content
assert events[0].invocation_id == "inv1"
assert events[1].author == "agent"
assert events[1].content.parts == agent_parts
runner.run_async.assert_called_once_with(
user_id="test_user",
session_id="test_session",
new_message=user_content,
)
@pytest.fixture
def mock_runner(mocker):
"""Provides a mock Runner for testing."""
mock_runner_cls = mocker.patch(
"google.adk.evaluation.evaluation_generator.Runner"
)
mock_runner_instance = mocker.AsyncMock()
mock_runner_instance.__aenter__.return_value = mock_runner_instance
mock_runner_cls.return_value = mock_runner_instance
yield mock_runner_instance
@pytest.fixture
def mock_session_service(mocker):
"""Provides a mock InMemorySessionService for testing."""
mock_session_service_cls = mocker.patch(
"google.adk.evaluation.evaluation_generator.InMemorySessionService"
)
mock_session_service_instance = mocker.MagicMock()
mock_session_service_instance.create_session = mocker.AsyncMock()
mock_session_service_cls.return_value = mock_session_service_instance
yield mock_session_service_instance
class TestGenerateInferencesFromRootAgent:
"""Test cases for EvaluationGenerator._generate_inferences_from_root_agent method."""
@pytest.mark.asyncio
async def test_generates_inferences_with_user_simulator(
self, mocker, mock_runner, mock_session_service
):
"""Tests that inferences are generated by interacting with a user simulator."""
mock_agent = mocker.MagicMock()
mock_user_sim = mocker.MagicMock(spec=UserSimulator)
# Mock user simulator will produce one message, then stop.
async def get_next_user_message_side_effect(*args, **kwargs):
if mock_user_sim.get_next_user_message.call_count == 1:
return NextUserMessage(
status=UserSimulatorStatus.SUCCESS,
user_message=types.Content(parts=[types.Part(text="message 1")]),
)
return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED)
mock_user_sim.get_next_user_message = mocker.AsyncMock(
side_effect=get_next_user_message_side_effect
)
mock_generate_inferences = mocker.patch(
"google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_for_single_user_invocation"
)
mocker.patch(
"google.adk.evaluation.evaluation_generator.EvaluationGenerator._get_app_details_by_invocation_id"
)
mocker.patch(
"google.adk.evaluation.evaluation_generator.EvaluationGenerator.convert_events_to_eval_invocations"
)
# Each call to _generate_inferences_for_single_user_invocation will
# yield one user and one agent event.
async def mock_generate_inferences_side_effect(
runner, user_id, session_id, user_content
):
yield _build_event("user", user_content.parts, "inv1")
yield _build_event("agent", [types.Part(text="agent_response")], "inv1")
mock_generate_inferences.side_effect = mock_generate_inferences_side_effect
await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=mock_agent,
user_simulator=mock_user_sim,
)
# Check that user simulator was called until it stopped.
assert mock_user_sim.get_next_user_message.call_count == 2
# Check that we generated inferences for each user message.
assert mock_generate_inferences.call_count == 1
# Check the content of the user messages passed to inference generation
mock_generate_inferences.assert_called_once()
called_with_content = mock_generate_inferences.call_args.args[3]
assert called_with_content.parts[0].text == "message 1"
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import sys
@@ -246,6 +248,7 @@ async def test_evaluate_success(
mock_eval_case = mocker.MagicMock(spec=EvalCase)
mock_eval_case.conversation = []
mock_eval_case.conversation_scenario = None
mock_eval_case.session_input = None
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
@@ -321,6 +324,7 @@ async def test_evaluate_single_inference_result(
invocation.model_copy(deep=True),
invocation.model_copy(deep=True),
]
mock_eval_case.conversation_scenario = None
mock_eval_case.session_input = None
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
@@ -352,6 +356,51 @@ async def test_evaluate_single_inference_result(
assert metric_result.eval_status == EvalStatus.PASSED
@pytest.mark.asyncio
async def test_evaluate_single_inference_result_skipped_for_conversation_scenario(
eval_service, mock_eval_sets_manager, mocker
):
"""To be removed once evaluation is implemented for conversation scenarios."""
invocation = Invocation(
user_content=genai_types.Content(
parts=[genai_types.Part(text="test user content.")]
),
final_response=genai_types.Content(
parts=[genai_types.Part(text="test final response.")]
),
)
inference_result = InferenceResult(
app_name="test_app",
eval_set_id="test_eval_set",
eval_case_id="case1",
inferences=[invocation.model_copy(deep=True)],
session_id="session1",
)
eval_metric = EvalMetric(metric_name="fake_metric", threshold=0.5)
evaluate_config = EvaluateConfig(eval_metrics=[eval_metric], parallelism=1)
mock_eval_case = mocker.MagicMock(spec=EvalCase)
mock_eval_case.conversation = None
mock_eval_case.conversation_scenario = mocker.MagicMock()
mock_eval_case.session_input = None
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
_, result = await eval_service._evaluate_single_inference_result(
inference_result=inference_result, evaluate_config=evaluate_config
)
assert isinstance(result, EvalCaseResult)
assert result.eval_id == "case1"
assert result.final_eval_status == EvalStatus.NOT_EVALUATED
assert not result.overall_eval_metric_results
assert len(result.eval_metric_result_per_invocation) == 1
invocation_result = result.eval_metric_result_per_invocation[0]
assert not invocation_result.eval_metric_results
assert (
invocation_result.expected_invocation.final_response.parts[0].text
== "N/A"
)
def test_generate_final_eval_status_doesn_t_throw_on(eval_service):
# How to fix if this test case fails?
# This test case has failed mainly because a new EvalStatus got added. You
@@ -0,0 +1,54 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.evaluation import static_user_simulator
from google.adk.evaluation import user_simulator
from google.adk.evaluation.eval_case import Invocation
from google.genai import types
import pytest
class TestStaticUserSimulator:
"""Test cases for StaticUserSimulator."""
@pytest.mark.asyncio
async def test_get_next_user_message(self):
"""Tests that the provided messages are returned in order followed by the stop signal."""
conversation = [
Invocation(
invocation_id="inv1",
user_content=types.Content(parts=[types.Part(text="message 1")]),
),
Invocation(
invocation_id="inv2",
user_content=types.Content(parts=[types.Part(text="message 2")]),
),
]
simulator = static_user_simulator.StaticUserSimulator(
static_conversation=conversation
)
next_message_1 = await simulator.get_next_user_message(events=[])
assert user_simulator.Status.SUCCESS == next_message_1.status
assert "message 1" == next_message_1.user_message.parts[0].text
next_message_2 = await simulator.get_next_user_message(events=[])
assert user_simulator.Status.SUCCESS == next_message_2.status
assert "message 2" == next_message_2.user_message.parts[0].text
next_message_3 = await simulator.get_next_user_message(events=[])
assert user_simulator.Status.STOP_SIGNAL_DETECTED == next_message_3.status
assert next_message_3.user_message is None
@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.evaluation.user_simulator import NextUserMessage
from google.adk.evaluation.user_simulator import Status
from google.genai.types import Content
import pytest
def test_next_user_message_validation():
"""Tests post-init validation of NextUserMessage."""
with pytest.raises(
ValueError,
match=(
"A user_message should be provided if and only if the status is"
" SUCCESS"
),
):
NextUserMessage(status=Status.SUCCESS)
with pytest.raises(
ValueError,
match=(
"A user_message should be provided if and only if the status is"
" SUCCESS"
),
):
NextUserMessage(status=Status.TURN_LIMIT_REACHED, user_message=Content())
# these two should not cause exceptions
NextUserMessage(status=Status.SUCCESS, user_message=Content())
NextUserMessage(status=Status.TURN_LIMIT_REACHED)
@@ -0,0 +1,79 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.evaluation import conversation_scenarios
from google.adk.evaluation import eval_case
from google.adk.evaluation import user_simulator_provider
from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator
from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
from google.adk.evaluation.static_user_simulator import StaticUserSimulator
from google.genai import types
import pytest
_TEST_CONVERSATION = [
eval_case.Invocation(
invocation_id='inv1',
user_content=types.Content(parts=[types.Part(text='Hello!')]),
),
]
_TEST_CONVERSATION_SCENARIO = conversation_scenarios.ConversationScenario(
starting_prompt='Hello!', conversation_plan='test plan'
)
class TestUserSimulatorProvider:
"""Test cases for the UserSimulatorProvider."""
def test_provide_static_user_simulator(self):
"""Tests the case when a StaticUserSimulator should be provided."""
provider = user_simulator_provider.UserSimulatorProvider()
test_eval_case = eval_case.EvalCase(
eval_id='test_eval_id',
conversation=_TEST_CONVERSATION,
)
simulator = provider.provide(test_eval_case)
assert isinstance(simulator, StaticUserSimulator)
assert simulator.static_conversation == _TEST_CONVERSATION
def test_provide_llm_backed_user_simulator(self, mocker):
"""Tests the case when a LlmBackedUserSimulator should be provided."""
mock_llm_registry = mocker.patch(
'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry',
autospec=True,
)
mock_llm_registry.return_value.resolve.return_value = mocker.Mock()
# Test case 1: No config in provider.
provider = user_simulator_provider.UserSimulatorProvider()
test_eval_case = eval_case.EvalCase(
eval_id='test_eval_id',
conversation_scenario=_TEST_CONVERSATION_SCENARIO,
)
simulator = provider.provide(test_eval_case)
assert isinstance(simulator, LlmBackedUserSimulator)
assert simulator._conversation_scenario == _TEST_CONVERSATION_SCENARIO
# Test case 2: Config in provider.
llm_config = LlmBackedUserSimulatorConfig(
model='test_model',
)
provider = user_simulator_provider.UserSimulatorProvider(
user_simulator_config=llm_config
)
simulator = provider.provide(test_eval_case)
assert isinstance(simulator, LlmBackedUserSimulator)
assert simulator._conversation_scenario == _TEST_CONVERSATION_SCENARIO
assert simulator._config.model == 'test_model'