You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Adds Static User Simulator and User Simulator Provider
Details: - Adds the `StaticUserSimulator` which implements the current functionality of supplying a fixed set of user prompts for an EvalCase. - Adds the `UserSimulatorProvider` which determines the type of user simulator required for an EvalCase (StaticUserSimulator or LlmBackedUserSimulator). - Integrates the UserSimulatorProvider and UserSimulator into the CLI and evaluation infrastructure. - Updates and adds unit tests for the new functionality. - Miscellaneous updates to lay groundwork for a full implementation of the LlmBackedUserSimulator in the future. PiperOrigin-RevId: 822198401
This commit is contained in:
committed by
Copybara-Service
parent
4a842c5a13
commit
aeaec859bf
@@ -549,6 +549,7 @@ def cli_eval(
|
||||
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
||||
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
|
||||
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||
from ..evaluation.user_simulator_provider import UserSimulatorProvider
|
||||
from .cli_eval import _collect_eval_results
|
||||
from .cli_eval import _collect_inferences
|
||||
from .cli_eval import get_root_agent
|
||||
@@ -638,11 +639,16 @@ def cli_eval(
|
||||
)
|
||||
)
|
||||
|
||||
user_simulator_provider = UserSimulatorProvider(
|
||||
user_simulator_config=eval_config.user_simulator_config
|
||||
)
|
||||
|
||||
try:
|
||||
eval_service = LocalEvalService(
|
||||
root_agent=root_agent,
|
||||
eval_sets_manager=eval_sets_manager,
|
||||
eval_set_results_manager=eval_set_results_manager,
|
||||
user_simulator_provider=user_simulator_provider,
|
||||
)
|
||||
|
||||
inference_results = asyncio.run(
|
||||
|
||||
@@ -50,6 +50,7 @@ from .eval_sets_manager import EvalSetsManager
|
||||
from .evaluator import EvalStatus
|
||||
from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
|
||||
from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema
|
||||
from .user_simulator_provider import UserSimulatorProvider
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
@@ -149,12 +150,17 @@ class AgentEvaluator:
|
||||
)
|
||||
eval_metrics = get_eval_metrics_from_config(eval_config)
|
||||
|
||||
user_simulator_provider = UserSimulatorProvider(
|
||||
user_simulator_config=eval_config.user_simulator_config
|
||||
)
|
||||
|
||||
# Step 1: Perform evals, basically inferencing and evaluation of metrics
|
||||
eval_results_by_eval_id = await AgentEvaluator._get_eval_results_by_eval_id(
|
||||
agent_for_eval=agent_for_eval,
|
||||
eval_set=eval_set,
|
||||
eval_metrics=eval_metrics,
|
||||
num_runs=num_runs,
|
||||
user_simulator_provider=user_simulator_provider,
|
||||
)
|
||||
|
||||
# Step 2: Post-process the results!
|
||||
@@ -518,6 +524,7 @@ class AgentEvaluator:
|
||||
eval_set: EvalSet,
|
||||
eval_metrics: list[EvalMetric],
|
||||
num_runs: int,
|
||||
user_simulator_provider: UserSimulatorProvider,
|
||||
) -> dict[str, list[EvalCaseResult]]:
|
||||
"""Returns EvalCaseResults grouped by eval case id.
|
||||
|
||||
@@ -541,6 +548,7 @@ class AgentEvaluator:
|
||||
eval_sets_manager=AgentEvaluator._get_eval_sets_manager(
|
||||
app_name=app_name, eval_set=eval_set
|
||||
),
|
||||
user_simulator_provider=user_simulator_provider,
|
||||
)
|
||||
|
||||
inference_requests = [
|
||||
|
||||
@@ -14,22 +14,19 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import Field
|
||||
|
||||
from .common import EvalBaseModel
|
||||
|
||||
|
||||
class ConversationScenario(EvalBaseModel):
|
||||
"""Scenario for a conversation between a simulated user and the Agent."""
|
||||
"""Scenario for a conversation between a simulated user and the Agent under test."""
|
||||
|
||||
starting_prompt: Union[str, genai_types.Content]
|
||||
starting_prompt: str
|
||||
"""Starting prompt for the conversation.
|
||||
|
||||
This prompt acts as the first user message that is given to the Agent. Any
|
||||
subsequent user messages are obtained by the system that is simulating the
|
||||
This prompt acts as the fixed first user message that is given to the Agent.
|
||||
Any subsequent user messages are obtained by the system that is simulating the
|
||||
user.
|
||||
"""
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Union
|
||||
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .app_details import AppDetails
|
||||
@@ -121,7 +122,7 @@ class SessionInput(EvalBaseModel):
|
||||
|
||||
|
||||
StaticConversation: TypeAlias = list[Invocation]
|
||||
"""A conversation where user's query for each invocation is already specified."""
|
||||
"""A conversation where the user's queries for each invocation are already specified."""
|
||||
|
||||
|
||||
class EvalCase(EvalBaseModel):
|
||||
@@ -158,6 +159,15 @@ class EvalCase(EvalBaseModel):
|
||||
)
|
||||
"""A list of rubrics that are applicable to all the invocations in the conversation of this eval case."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_conversation_xor_conversation_scenario(self) -> EvalCase:
|
||||
if (self.conversation is None) == (self.conversation_scenario is None):
|
||||
raise ValueError(
|
||||
"Exactly one of conversation and conversation_scenario must be"
|
||||
" provided in an EvalCase."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
def get_all_tool_calls(
|
||||
intermediate_data: Optional[IntermediateDataType],
|
||||
|
||||
@@ -73,7 +73,7 @@ the third one uses `LlmAsAJudgeCriterion`.
|
||||
|
||||
user_simulator_config: Optional[BaseUserSimulatorConfig] = Field(
|
||||
default=None,
|
||||
description="""Config to be used by the user simulator.""",
|
||||
description="Config to be used by the user simulator.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai.types import Content
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.llm_agent import Agent
|
||||
@@ -41,6 +44,9 @@ from .eval_case import InvocationEvents
|
||||
from .eval_case import SessionInput
|
||||
from .eval_set import EvalSet
|
||||
from .request_intercepter_plugin import _RequestIntercepterPlugin
|
||||
from .user_simulator import Status as UserSimulatorStatus
|
||||
from .user_simulator import UserSimulator
|
||||
from .user_simulator_provider import UserSimulatorProvider
|
||||
|
||||
_USER_AUTHOR = "user"
|
||||
_DEFAULT_AUTHOR = "agent"
|
||||
@@ -79,11 +85,13 @@ class EvaluationGenerator:
|
||||
results = []
|
||||
|
||||
for eval_case in eval_set.eval_cases:
|
||||
# assume only static conversations are needed
|
||||
user_simulator = UserSimulatorProvider().provide(eval_case)
|
||||
responses = []
|
||||
for _ in range(repeat_num):
|
||||
response_invocations = await EvaluationGenerator._process_query(
|
||||
eval_case.conversation,
|
||||
agent_module_path,
|
||||
user_simulator,
|
||||
agent_name,
|
||||
eval_case.session_input,
|
||||
)
|
||||
@@ -123,8 +131,8 @@ class EvaluationGenerator:
|
||||
|
||||
@staticmethod
|
||||
async def _process_query(
|
||||
invocations: list[Invocation],
|
||||
module_name: str,
|
||||
user_simulator: UserSimulator,
|
||||
agent_name: Optional[str] = None,
|
||||
initial_session: Optional[SessionInput] = None,
|
||||
) -> list[Invocation]:
|
||||
@@ -141,13 +149,44 @@ class EvaluationGenerator:
|
||||
assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
|
||||
|
||||
return await EvaluationGenerator._generate_inferences_from_root_agent(
|
||||
invocations, agent_to_evaluate, reset_func, initial_session
|
||||
agent_to_evaluate,
|
||||
user_simulator=user_simulator,
|
||||
reset_func=reset_func,
|
||||
initial_session=initial_session,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _generate_inferences_for_single_user_invocation(
|
||||
runner: Runner,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
user_content: Content,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
invocation_id = None
|
||||
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=user_content,
|
||||
)
|
||||
) as agen:
|
||||
|
||||
async for event in agen:
|
||||
if not invocation_id:
|
||||
invocation_id = event.invocation_id
|
||||
yield Event(
|
||||
content=user_content,
|
||||
author=_USER_AUTHOR,
|
||||
invocation_id=invocation_id,
|
||||
)
|
||||
|
||||
yield event
|
||||
|
||||
@staticmethod
|
||||
async def _generate_inferences_from_root_agent(
|
||||
invocations: list[Invocation],
|
||||
root_agent: Agent,
|
||||
user_simulator: UserSimulator,
|
||||
reset_func: Optional[Any] = None,
|
||||
initial_session: Optional[SessionInput] = None,
|
||||
session_id: Optional[str] = None,
|
||||
@@ -155,7 +194,8 @@ class EvaluationGenerator:
|
||||
artifact_service: Optional[BaseArtifactService] = None,
|
||||
memory_service: Optional[BaseMemoryService] = None,
|
||||
) -> list[Invocation]:
|
||||
"""Scrapes the root agent given the list of Invocations."""
|
||||
"""Scrapes the root agent in coordination with the user simulator."""
|
||||
|
||||
if not session_service:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
@@ -194,29 +234,19 @@ class EvaluationGenerator:
|
||||
plugins=[request_intercepter_plugin],
|
||||
) as runner:
|
||||
events = []
|
||||
|
||||
for invocation in invocations:
|
||||
user_content = invocation.user_content
|
||||
invocation_id = None
|
||||
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=user_id, session_id=session_id, new_message=user_content
|
||||
)
|
||||
) as agen:
|
||||
|
||||
async for event in agen:
|
||||
if not invocation_id:
|
||||
invocation_id = event.invocation_id
|
||||
events.append(
|
||||
Event(
|
||||
content=user_content,
|
||||
author=_USER_AUTHOR,
|
||||
invocation_id=invocation_id,
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
next_user_message = await user_simulator.get_next_user_message(
|
||||
copy.deepcopy(events)
|
||||
)
|
||||
if next_user_message.status == UserSimulatorStatus.SUCCESS:
|
||||
async for (
|
||||
event
|
||||
) in EvaluationGenerator._generate_inferences_for_single_user_invocation(
|
||||
runner, user_id, session_id, next_user_message.user_message
|
||||
):
|
||||
events.append(event)
|
||||
else: # no message generated
|
||||
break
|
||||
|
||||
app_details_by_invocation_id = (
|
||||
EvaluationGenerator._get_app_details_by_invocation_id(
|
||||
|
||||
@@ -22,7 +22,9 @@ from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..models.registry import LLMRegistry
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .conversation_scenarios import ConversationScenario
|
||||
from .evaluator import Evaluator
|
||||
from .user_simulator import BaseUserSimulatorConfig
|
||||
from .user_simulator import NextUserMessage
|
||||
@@ -37,52 +39,59 @@ class LlmBackedUserSimulatorConfig(BaseUserSimulatorConfig):
|
||||
description="The model to use for user simulation.",
|
||||
)
|
||||
|
||||
model_config: Optional[genai_types.GenerateContentConfig] = Field(
|
||||
default=genai_types.GenerateContentConfig,
|
||||
model_configuration: genai_types.GenerateContentConfig = Field(
|
||||
default_factory=genai_types.GenerateContentConfig,
|
||||
description="The configuration for the model.",
|
||||
)
|
||||
|
||||
max_allowed_invocations: int = Field(
|
||||
default=20,
|
||||
description="""Maximum number of invocations allowed by the simulated
|
||||
interaction. This property allows us to stop a run-off conversation, where the
|
||||
agent and the user simulator get into an never ending loop.
|
||||
interaction. This property allows us to stop a run-off conversation, where the
|
||||
agent and the user simulator get into a never ending loop. The initial fixed
|
||||
prompt is also counted as an invocation.
|
||||
|
||||
(Not recommended)If you don't want a limit, you can set the value to -1.
|
||||
""",
|
||||
(Not recommended) If you don't want a limit, you can set the value to -1.""",
|
||||
)
|
||||
|
||||
|
||||
@experimental
|
||||
class LlmBackedUserSimulator(UserSimulator):
|
||||
"""A UserSimulator that uses a LLM to generate messages on behalf of the user."""
|
||||
"""A UserSimulator that uses an LLM to generate messages on behalf of the user."""
|
||||
|
||||
config_type: ClassVar[type[LlmBackedUserSimulatorConfig]] = (
|
||||
LlmBackedUserSimulatorConfig
|
||||
)
|
||||
|
||||
def __init__(self, *, config: BaseUserSimulatorConfig):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: BaseUserSimulatorConfig,
|
||||
conversation_scenario: ConversationScenario,
|
||||
):
|
||||
super().__init__(config, config_type=LlmBackedUserSimulator.config_type)
|
||||
self._conversation_scenario = conversation_scenario
|
||||
|
||||
@override
|
||||
async def get_next_user_message(
|
||||
self,
|
||||
conversation_plan: str,
|
||||
events: list[Event],
|
||||
) -> NextUserMessage:
|
||||
"""Returns the next user message to send to the agent with help from a LLM.
|
||||
|
||||
Args:
|
||||
conversation_plan: A plan that user simulation system needs to follow as
|
||||
it plays out the conversation.
|
||||
events: The unaltered conversation history between the user and the
|
||||
agent(s) under evaluation.
|
||||
|
||||
Returns:
|
||||
A NextUserMessage object containing the next user message to send to the
|
||||
agent, or a status indicating why no message was generated.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@override
|
||||
def get_simulation_evaluator(
|
||||
self,
|
||||
) -> Evaluator:
|
||||
) -> Optional[Evaluator]:
|
||||
"""Returns an Evaluator that evaluates if the simulation was successful or not."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -22,6 +22,8 @@ from typing import Callable
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.genai.types import Content
|
||||
from google.genai.types import Part
|
||||
from typing_extensions import override
|
||||
|
||||
from ..agents.base_agent import BaseAgent
|
||||
@@ -51,6 +53,7 @@ from .evaluator import EvalStatus
|
||||
from .evaluator import EvaluationResult
|
||||
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
|
||||
from .metric_evaluator_registry import MetricEvaluatorRegistry
|
||||
from .user_simulator_provider import UserSimulatorProvider
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
@@ -74,6 +77,7 @@ class LocalEvalService(BaseEvalService):
|
||||
artifact_service: Optional[BaseArtifactService] = None,
|
||||
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
|
||||
session_id_supplier: Callable[[], str] = _get_session_id,
|
||||
user_simulator_provider: UserSimulatorProvider = UserSimulatorProvider(),
|
||||
):
|
||||
self._root_agent = root_agent
|
||||
self._eval_sets_manager = eval_sets_manager
|
||||
@@ -87,6 +91,7 @@ class LocalEvalService(BaseEvalService):
|
||||
self._artifact_service = artifact_service
|
||||
self._eval_set_results_manager = eval_set_results_manager
|
||||
self._session_id_supplier = session_id_supplier
|
||||
self._user_simulator_provider = user_simulator_provider
|
||||
|
||||
@override
|
||||
async def perform_inference(
|
||||
@@ -211,6 +216,48 @@ class LocalEvalService(BaseEvalService):
|
||||
# would be the score for the eval case.
|
||||
overall_eval_metric_results = []
|
||||
|
||||
user_id = (
|
||||
eval_case.session_input.user_id
|
||||
if eval_case.session_input and eval_case.session_input.user_id
|
||||
else 'test_user_id'
|
||||
)
|
||||
|
||||
if eval_case.conversation_scenario:
|
||||
logger.warning(
|
||||
'Skipping evaluation of variable-length conversation scenario in eval'
|
||||
' set/case %s/%s.',
|
||||
inference_result.eval_set_id,
|
||||
inference_result.eval_case_id,
|
||||
)
|
||||
for actual_invocation in inference_result.inferences:
|
||||
eval_metric_result_per_invocation.append(
|
||||
EvalMetricResultPerInvocation(
|
||||
actual_invocation=actual_invocation,
|
||||
expected_invocation=Invocation(
|
||||
user_content=actual_invocation.user_content,
|
||||
final_response=Content(
|
||||
parts=[Part(text='N/A')], role='model'
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
eval_case_result = EvalCaseResult(
|
||||
eval_set_file=inference_result.eval_set_id,
|
||||
eval_set_id=inference_result.eval_set_id,
|
||||
eval_id=inference_result.eval_case_id,
|
||||
final_eval_status=EvalStatus.NOT_EVALUATED,
|
||||
overall_eval_metric_results=overall_eval_metric_results,
|
||||
eval_metric_result_per_invocation=eval_metric_result_per_invocation,
|
||||
session_id=inference_result.session_id,
|
||||
session_details=await self._session_service.get_session(
|
||||
app_name=inference_result.app_name,
|
||||
user_id=user_id,
|
||||
session_id=inference_result.session_id,
|
||||
),
|
||||
user_id=user_id,
|
||||
)
|
||||
return (inference_result, eval_case_result)
|
||||
|
||||
if len(inference_result.inferences) != len(eval_case.conversation):
|
||||
raise ValueError(
|
||||
'Inferences should match conversations in eval case. Found'
|
||||
@@ -281,11 +328,6 @@ class LocalEvalService(BaseEvalService):
|
||||
final_eval_status = self._generate_final_eval_status(
|
||||
overall_eval_metric_results
|
||||
)
|
||||
user_id = (
|
||||
eval_case.session_input.user_id
|
||||
if eval_case.session_input and eval_case.session_input.user_id
|
||||
else 'test_user_id'
|
||||
)
|
||||
|
||||
eval_case_result = EvalCaseResult(
|
||||
eval_set_file=inference_result.eval_set_id,
|
||||
@@ -373,8 +415,8 @@ class LocalEvalService(BaseEvalService):
|
||||
try:
|
||||
inferences = (
|
||||
await EvaluationGenerator._generate_inferences_from_root_agent(
|
||||
invocations=eval_case.conversation,
|
||||
root_agent=root_agent,
|
||||
user_simulator=self._user_simulator_provider.provide(eval_case),
|
||||
initial_session=initial_session,
|
||||
session_id=session_id,
|
||||
session_service=self._session_service,
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .eval_case import StaticConversation
|
||||
from .evaluator import Evaluator
|
||||
from .user_simulator import BaseUserSimulatorConfig
|
||||
from .user_simulator import NextUserMessage
|
||||
from .user_simulator import Status
|
||||
from .user_simulator import UserSimulator
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
@experimental
|
||||
class StaticUserSimulator(UserSimulator):
|
||||
"""A UserSimulator that returns a static list of user messages."""
|
||||
|
||||
def __init__(self, *, static_conversation: StaticConversation):
|
||||
super().__init__(
|
||||
BaseUserSimulatorConfig(), config_type=BaseUserSimulatorConfig
|
||||
)
|
||||
self.static_conversation = static_conversation
|
||||
self.invocation_idx = 0
|
||||
|
||||
@override
|
||||
async def get_next_user_message(
|
||||
self,
|
||||
events: list[Event],
|
||||
) -> NextUserMessage:
|
||||
"""Returns the next user message to send to the agent from a static list.
|
||||
|
||||
Args:
|
||||
events: The unaltered conversation history between the user and the
|
||||
agent(s) under evaluation.
|
||||
|
||||
Returns:
|
||||
A NextUserMessage object containing the next user message to send to the
|
||||
agent, or a status indicating why no message was generated.
|
||||
"""
|
||||
# check if we have reached the end of the list of invocations
|
||||
if self.invocation_idx >= len(self.static_conversation):
|
||||
return NextUserMessage(status=Status.STOP_SIGNAL_DETECTED)
|
||||
|
||||
# return the next message in the static list
|
||||
next_user_content = self.static_conversation[
|
||||
self.invocation_idx
|
||||
].user_content
|
||||
self.invocation_idx += 1
|
||||
return NextUserMessage(
|
||||
status=Status.SUCCESS,
|
||||
user_message=next_user_content,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_simulation_evaluator(
|
||||
self,
|
||||
) -> Optional[Evaluator]:
|
||||
"""The StaticUserSimulator does not require an evaluator."""
|
||||
return None
|
||||
@@ -23,6 +23,7 @@ from pydantic import alias_generators
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ..events.event import Event
|
||||
@@ -32,7 +33,7 @@ from .evaluator import Evaluator
|
||||
|
||||
|
||||
class BaseUserSimulatorConfig(BaseModel):
|
||||
"""Base class for configurations pertaining to User Simulator."""
|
||||
"""Base class for configurations pertaining to user simulator."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
alias_generator=alias_generators.to_camel,
|
||||
@@ -60,13 +61,25 @@ not able to do so."""
|
||||
)
|
||||
|
||||
user_message: Optional[genai_types.Content] = Field(
|
||||
description="""The next user message."""
|
||||
description="""The next user message.""", default=None
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_user_message_iff_success(self) -> NextUserMessage:
|
||||
if (self.status == Status.SUCCESS) == (self.user_message is None):
|
||||
raise ValueError(
|
||||
"A user_message should be provided if and only if the status is"
|
||||
" SUCCESS"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@experimental
|
||||
class UserSimulator(ABC):
|
||||
"""A user simulator for the purposes of automating interaction with an Agent."""
|
||||
"""A user simulator for the purposes of automating interaction with an Agent.
|
||||
|
||||
Typically, you must create one user simulator instance per eval case.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -82,21 +95,22 @@ class UserSimulator(ABC):
|
||||
|
||||
async def get_next_user_message(
|
||||
self,
|
||||
conversation_plan: str,
|
||||
events: list[Event],
|
||||
) -> NextUserMessage:
|
||||
"""Returns the next user message to send to the agent.
|
||||
|
||||
Args:
|
||||
conversation_plan: A plan that user simulation system needs to follow as
|
||||
it plays out the conversation.
|
||||
events: The unaltered conversation history between the user and the
|
||||
agent(s) under evaluation.
|
||||
|
||||
Returns:
|
||||
A NextUserMessage object containing the next user message to send to the
|
||||
agent, or a status indicating why no message was generated.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_simulation_evaluator(
|
||||
self,
|
||||
) -> Evaluator:
|
||||
"""Returns an instnace of an Evaluator that evaluates if the simulation was successful or not."""
|
||||
) -> Optional[Evaluator]:
|
||||
"""Returns an instance of an Evaluator that evaluates if the user simulation was successful or not."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .eval_case import EvalCase
|
||||
from .llm_backed_user_simulator import LlmBackedUserSimulator
|
||||
from .static_user_simulator import StaticUserSimulator
|
||||
from .user_simulator import BaseUserSimulatorConfig
|
||||
from .user_simulator import UserSimulator
|
||||
|
||||
|
||||
@experimental
|
||||
class UserSimulatorProvider:
|
||||
"""Provides a UserSimulator instance per EvalCase, mixing configuration data
|
||||
from the EvalConfig with per-EvalCase conversation data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_simulator_config: Optional[BaseUserSimulatorConfig] = None,
|
||||
):
|
||||
if user_simulator_config is None:
|
||||
user_simulator_config = BaseUserSimulatorConfig()
|
||||
elif not isinstance(user_simulator_config, BaseUserSimulatorConfig):
|
||||
# assume that the user simulator will fully validate the config it gets.
|
||||
raise ValueError(f"Expect config of type `{BaseUserSimulatorConfig}`.")
|
||||
self._user_simulator_config = user_simulator_config
|
||||
|
||||
def provide(self, eval_case: EvalCase) -> UserSimulator:
|
||||
"""Provide an appropriate user simulator based on the type of conversation data in the EvalCase
|
||||
|
||||
Args:
|
||||
eval_case: An EvalCase containing a `conversation` xor a
|
||||
`conversation_scenario`.
|
||||
|
||||
Returns:
|
||||
A StaticUserSimulator or an LlmBackedUserSimulator based on the type of
|
||||
the conversation data.
|
||||
|
||||
Raises:
|
||||
ValueError: If no conversation data or multiple types of conversation data
|
||||
are provided.
|
||||
"""
|
||||
if eval_case.conversation is None:
|
||||
if eval_case.conversation_scenario is None:
|
||||
raise ValueError(
|
||||
"Neither static invocations nor conversation scenario provided in"
|
||||
" EvalCase. Provide exactly one."
|
||||
)
|
||||
|
||||
return LlmBackedUserSimulator(
|
||||
config=self._user_simulator_config,
|
||||
conversation_scenario=eval_case.conversation_scenario,
|
||||
)
|
||||
|
||||
else: # eval_case.conversation is not None
|
||||
if eval_case.conversation_scenario is not None:
|
||||
raise ValueError(
|
||||
"Both static invocations and conversation scenario provided in"
|
||||
" EvalCase. Provide exactly one."
|
||||
)
|
||||
|
||||
return StaticUserSimulator(static_conversation=eval_case.conversation)
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.evaluation.conversation_scenarios import ConversationScenario
|
||||
from google.adk.evaluation.eval_case import EvalCase
|
||||
from google.adk.evaluation.eval_case import get_all_tool_calls
|
||||
from google.adk.evaluation.eval_case import get_all_tool_calls_with_responses
|
||||
from google.adk.evaluation.eval_case import get_all_tool_responses
|
||||
@@ -246,3 +248,36 @@ def test_get_all_tool_calls_with_responses_with_invocation_events():
|
||||
(tool_call1, tool_response1),
|
||||
(tool_call2, None),
|
||||
]
|
||||
|
||||
|
||||
def test_conversation_and_conversation_scenario_mutual_exclusion():
|
||||
"""Tests the ensure_conversation_xor_conversation_scenario validator."""
|
||||
test_conversation_scenario = ConversationScenario(
|
||||
starting_prompt='', conversation_plan=''
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
'Exactly one of conversation and conversation_scenario must be'
|
||||
' provided in an EvalCase.'
|
||||
),
|
||||
):
|
||||
EvalCase(eval_id='test_id')
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
'Exactly one of conversation and conversation_scenario must be'
|
||||
' provided in an EvalCase.'
|
||||
),
|
||||
):
|
||||
EvalCase(
|
||||
eval_id='test_id',
|
||||
conversation=[],
|
||||
conversation_scenario=test_conversation_scenario,
|
||||
)
|
||||
|
||||
# these two should not cause exceptions
|
||||
EvalCase(eval_id='test_id', conversation=[])
|
||||
EvalCase(eval_id='test_id', conversation_scenario=test_conversation_scenario)
|
||||
|
||||
@@ -18,9 +18,13 @@ from google.adk.evaluation.app_details import AgentDetails
|
||||
from google.adk.evaluation.app_details import AppDetails
|
||||
from google.adk.evaluation.evaluation_generator import EvaluationGenerator
|
||||
from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin
|
||||
from google.adk.evaluation.user_simulator import NextUserMessage
|
||||
from google.adk.evaluation.user_simulator import Status as UserSimulatorStatus
|
||||
from google.adk.evaluation.user_simulator import UserSimulator
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
def _build_event(
|
||||
@@ -324,3 +328,130 @@ class TestGetAppDetailsByInvocationId:
|
||||
}
|
||||
assert app_details == expected_app_details
|
||||
assert mock_request_intercepter.get_model_request.call_count == 3
|
||||
|
||||
|
||||
class TestGenerateInferencesForSingleUserInvocation:
|
||||
"""Test cases for EvaluationGenerator._generate_inferences_for_single_user_invocation method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_inferences_with_mock_runner(self, mocker):
|
||||
"""Tests inference generation with a mocked runner."""
|
||||
runner = mocker.MagicMock()
|
||||
|
||||
agent_parts = [types.Part(text="Agent response")]
|
||||
|
||||
async def mock_run_async(*args, **kwargs):
|
||||
yield _build_event(
|
||||
author="agent",
|
||||
parts=agent_parts,
|
||||
invocation_id="inv1",
|
||||
)
|
||||
|
||||
runner.run_async.return_value = mock_run_async()
|
||||
|
||||
user_content = types.Content(parts=[types.Part(text="User query")])
|
||||
|
||||
events = [
|
||||
event
|
||||
async for event in EvaluationGenerator._generate_inferences_for_single_user_invocation(
|
||||
runner, "test_user", "test_session", user_content
|
||||
)
|
||||
]
|
||||
|
||||
assert len(events) == 2
|
||||
assert events[0].author == "user"
|
||||
assert events[0].content == user_content
|
||||
assert events[0].invocation_id == "inv1"
|
||||
assert events[1].author == "agent"
|
||||
assert events[1].content.parts == agent_parts
|
||||
|
||||
runner.run_async.assert_called_once_with(
|
||||
user_id="test_user",
|
||||
session_id="test_session",
|
||||
new_message=user_content,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(mocker):
|
||||
"""Provides a mock Runner for testing."""
|
||||
mock_runner_cls = mocker.patch(
|
||||
"google.adk.evaluation.evaluation_generator.Runner"
|
||||
)
|
||||
mock_runner_instance = mocker.AsyncMock()
|
||||
mock_runner_instance.__aenter__.return_value = mock_runner_instance
|
||||
mock_runner_cls.return_value = mock_runner_instance
|
||||
yield mock_runner_instance
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session_service(mocker):
  """Provides a mock InMemorySessionService for testing."""
  service_instance = mocker.MagicMock()
  # create_session is awaited by the code under test, so it must be async.
  service_instance.create_session = mocker.AsyncMock()
  patched_service_cls = mocker.patch(
      "google.adk.evaluation.evaluation_generator.InMemorySessionService"
  )
  patched_service_cls.return_value = service_instance
  yield service_instance
|
||||
|
||||
class TestGenerateInferencesFromRootAgent:
  """Test cases for EvaluationGenerator._generate_inferences_from_root_agent method."""

  @pytest.mark.asyncio
  async def test_generates_inferences_with_user_simulator(
      self, mocker, mock_runner, mock_session_service
  ):
    """Tests that inferences are generated by interacting with a user simulator."""
    fake_agent = mocker.MagicMock()
    user_sim = mocker.MagicMock(spec=UserSimulator)

    # The simulated user sends exactly one message, then signals a stop.
    scripted_replies = iter([
        NextUserMessage(
            status=UserSimulatorStatus.SUCCESS,
            user_message=types.Content(parts=[types.Part(text="message 1")]),
        ),
    ])

    async def next_message_side_effect(*args, **kwargs):
      return next(
          scripted_replies,
          NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED),
      )

    user_sim.get_next_user_message = mocker.AsyncMock(
        side_effect=next_message_side_effect
    )

    single_invocation_patch = mocker.patch(
        "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_for_single_user_invocation"
    )
    mocker.patch(
        "google.adk.evaluation.evaluation_generator.EvaluationGenerator._get_app_details_by_invocation_id"
    )
    mocker.patch(
        "google.adk.evaluation.evaluation_generator.EvaluationGenerator.convert_events_to_eval_invocations"
    )

    # Every simulated user turn yields one user event and one agent event.
    async def single_invocation_side_effect(
        runner, user_id, session_id, user_content
    ):
      yield _build_event("user", user_content.parts, "inv1")
      yield _build_event("agent", [types.Part(text="agent_response")], "inv1")

    single_invocation_patch.side_effect = single_invocation_side_effect

    await EvaluationGenerator._generate_inferences_from_root_agent(
        root_agent=fake_agent,
        user_simulator=user_sim,
    )

    # The simulator was polled until it reported the stop signal.
    assert user_sim.get_next_user_message.call_count == 2

    # Inference generation ran once, receiving the simulated user's message.
    single_invocation_patch.assert_called_once()
    forwarded_content = single_invocation_patch.call_args.args[3]
    assert forwarded_content.parts[0].text == "message 1"
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
@@ -246,6 +248,7 @@ async def test_evaluate_success(
|
||||
|
||||
mock_eval_case = mocker.MagicMock(spec=EvalCase)
|
||||
mock_eval_case.conversation = []
|
||||
mock_eval_case.conversation_scenario = None
|
||||
mock_eval_case.session_input = None
|
||||
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
|
||||
|
||||
@@ -321,6 +324,7 @@ async def test_evaluate_single_inference_result(
|
||||
invocation.model_copy(deep=True),
|
||||
invocation.model_copy(deep=True),
|
||||
]
|
||||
mock_eval_case.conversation_scenario = None
|
||||
mock_eval_case.session_input = None
|
||||
mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
|
||||
|
||||
@@ -352,6 +356,51 @@ async def test_evaluate_single_inference_result(
|
||||
assert metric_result.eval_status == EvalStatus.PASSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_evaluate_single_inference_result_skipped_for_conversation_scenario(
    eval_service, mock_eval_sets_manager, mocker
):
  """To be removed once evaluation is implemented for conversation scenarios."""
  test_invocation = Invocation(
      user_content=genai_types.Content(
          parts=[genai_types.Part(text="test user content.")]
      ),
      final_response=genai_types.Content(
          parts=[genai_types.Part(text="test final response.")]
      ),
  )
  inference_result = InferenceResult(
      app_name="test_app",
      eval_set_id="test_eval_set",
      eval_case_id="case1",
      inferences=[test_invocation.model_copy(deep=True)],
      session_id="session1",
  )
  evaluate_config = EvaluateConfig(
      eval_metrics=[EvalMetric(metric_name="fake_metric", threshold=0.5)],
      parallelism=1,
  )

  # An eval case driven by a conversation scenario, not a static conversation.
  scenario_eval_case = mocker.MagicMock(spec=EvalCase)
  scenario_eval_case.conversation = None
  scenario_eval_case.conversation_scenario = mocker.MagicMock()
  scenario_eval_case.session_input = None
  mock_eval_sets_manager.get_eval_case.return_value = scenario_eval_case

  _, case_result = await eval_service._evaluate_single_inference_result(
      inference_result=inference_result, evaluate_config=evaluate_config
  )

  # Evaluation is skipped: no metric results, status NOT_EVALUATED, and a
  # placeholder expected invocation.
  assert isinstance(case_result, EvalCaseResult)
  assert case_result.eval_id == "case1"
  assert case_result.final_eval_status == EvalStatus.NOT_EVALUATED
  assert not case_result.overall_eval_metric_results
  assert len(case_result.eval_metric_result_per_invocation) == 1
  per_invocation = case_result.eval_metric_result_per_invocation[0]
  assert not per_invocation.eval_metric_results
  assert (
      per_invocation.expected_invocation.final_response.parts[0].text == "N/A"
  )
||||
|
||||
|
||||
def test_generate_final_eval_status_doesn_t_throw_on(eval_service):
|
||||
# How to fix if this test case fails?
|
||||
# This test case has failed mainly because a new EvalStatus got added. You
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.evaluation import static_user_simulator
|
||||
from google.adk.evaluation import user_simulator
|
||||
from google.adk.evaluation.eval_case import Invocation
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
|
||||
class TestStaticUserSimulator:
  """Test cases for StaticUserSimulator."""

  @pytest.mark.asyncio
  async def test_get_next_user_message(self):
    """Tests that the provided messages are returned in order followed by the stop signal."""
    scripted_conversation = [
        Invocation(
            invocation_id="inv1",
            user_content=types.Content(parts=[types.Part(text="message 1")]),
        ),
        Invocation(
            invocation_id="inv2",
            user_content=types.Content(parts=[types.Part(text="message 2")]),
        ),
    ]
    simulator = static_user_simulator.StaticUserSimulator(
        static_conversation=scripted_conversation
    )

    # The scripted user messages come back one per call, in order.
    for expected_text in ("message 1", "message 2"):
      next_message = await simulator.get_next_user_message(events=[])
      assert next_message.status == user_simulator.Status.SUCCESS
      assert next_message.user_message.parts[0].text == expected_text

    # Once the script is exhausted, the simulator signals a stop.
    final_message = await simulator.get_next_user_message(events=[])
    assert final_message.status == user_simulator.Status.STOP_SIGNAL_DETECTED
    assert final_message.user_message is None
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.evaluation.user_simulator import NextUserMessage
|
||||
from google.adk.evaluation.user_simulator import Status
|
||||
from google.genai.types import Content
|
||||
import pytest
|
||||
|
||||
|
||||
def test_next_user_message_validation():
  """Tests post-init validation of NextUserMessage."""
  expected_error = (
      "A user_message should be provided if and only if the status is"
      " SUCCESS"
  )

  # SUCCESS without a user_message is rejected.
  with pytest.raises(ValueError, match=expected_error):
    NextUserMessage(status=Status.SUCCESS)

  # A user_message paired with a non-SUCCESS status is also rejected.
  with pytest.raises(ValueError, match=expected_error):
    NextUserMessage(status=Status.TURN_LIMIT_REACHED, user_message=Content())

  # Valid combinations construct without raising.
  NextUserMessage(status=Status.SUCCESS, user_message=Content())
  NextUserMessage(status=Status.TURN_LIMIT_REACHED)
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.evaluation import conversation_scenarios
|
||||
from google.adk.evaluation import eval_case
|
||||
from google.adk.evaluation import user_simulator_provider
|
||||
from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator
|
||||
from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
|
||||
from google.adk.evaluation.static_user_simulator import StaticUserSimulator
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
# A minimal static conversation (one scripted user invocation) used to
# exercise the StaticUserSimulator path of the provider.
_TEST_CONVERSATION = [
    eval_case.Invocation(
        invocation_id='inv1',
        user_content=types.Content(parts=[types.Part(text='Hello!')]),
    ),
]

# A minimal conversation scenario used to exercise the LLM-backed
# simulator path of the provider.
_TEST_CONVERSATION_SCENARIO = conversation_scenarios.ConversationScenario(
    starting_prompt='Hello!', conversation_plan='test plan'
)
||||
|
||||
|
||||
class TestUserSimulatorProvider:
  """Test cases for the UserSimulatorProvider."""

  def test_provide_static_user_simulator(self):
    """Tests the case when a StaticUserSimulator should be provided."""
    static_case = eval_case.EvalCase(
        eval_id='test_eval_id',
        conversation=_TEST_CONVERSATION,
    )
    provider = user_simulator_provider.UserSimulatorProvider()

    result = provider.provide(static_case)

    assert isinstance(result, StaticUserSimulator)
    assert result.static_conversation == _TEST_CONVERSATION

  def test_provide_llm_backed_user_simulator(self, mocker):
    """Tests the case when a LlmBackedUserSimulator should be provided."""
    registry_patch = mocker.patch(
        'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry',
        autospec=True,
    )
    registry_patch.return_value.resolve.return_value = mocker.Mock()
    scenario_case = eval_case.EvalCase(
        eval_id='test_eval_id',
        conversation_scenario=_TEST_CONVERSATION_SCENARIO,
    )

    # Case 1: no simulator config supplied to the provider.
    default_provider = user_simulator_provider.UserSimulatorProvider()
    result = default_provider.provide(scenario_case)
    assert isinstance(result, LlmBackedUserSimulator)
    assert result._conversation_scenario == _TEST_CONVERSATION_SCENARIO

    # Case 2: an explicit LLM simulator config is supplied.
    configured_provider = user_simulator_provider.UserSimulatorProvider(
        user_simulator_config=LlmBackedUserSimulatorConfig(
            model='test_model',
        )
    )
    result = configured_provider.provide(scenario_case)
    assert isinstance(result, LlmBackedUserSimulator)
    assert result._conversation_scenario == _TEST_CONVERSATION_SCENARIO
    assert result._config.model == 'test_model'
|
||||
Reference in New Issue
Block a user