diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
index 7115b5fc..a47d20f7 100644
--- a/src/google/adk/cli/cli_tools_click.py
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -549,6 +549,7 @@ def cli_eval(
   from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
   from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
   from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
+  from ..evaluation.user_simulator_provider import UserSimulatorProvider
   from .cli_eval import _collect_eval_results
   from .cli_eval import _collect_inferences
   from .cli_eval import get_root_agent
@@ -638,11 +639,16 @@ def cli_eval(
       )
   )

+  user_simulator_provider = UserSimulatorProvider(
+      user_simulator_config=eval_config.user_simulator_config
+  )
+
   try:
     eval_service = LocalEvalService(
         root_agent=root_agent,
         eval_sets_manager=eval_sets_manager,
         eval_set_results_manager=eval_set_results_manager,
+        user_simulator_provider=user_simulator_provider,
     )

     inference_results = asyncio.run(
diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py
index 3788a2e6..ff6666d6 100644
--- a/src/google/adk/evaluation/agent_evaluator.py
+++ b/src/google/adk/evaluation/agent_evaluator.py
@@ -50,6 +50,7 @@ from .eval_sets_manager import EvalSetsManager
 from .evaluator import EvalStatus
 from .in_memory_eval_sets_manager import InMemoryEvalSetsManager
 from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema
+from .user_simulator_provider import UserSimulatorProvider

 logger = logging.getLogger("google_adk." + __name__)

@@ -149,12 +150,17 @@ class AgentEvaluator:
     )

     eval_metrics = get_eval_metrics_from_config(eval_config)
+    user_simulator_provider = UserSimulatorProvider(
+        user_simulator_config=eval_config.user_simulator_config
+    )
+
     # Step 1: Perform evals, basically inferencing and evaluation of metrics
     eval_results_by_eval_id = await AgentEvaluator._get_eval_results_by_eval_id(
         agent_for_eval=agent_for_eval,
         eval_set=eval_set,
         eval_metrics=eval_metrics,
         num_runs=num_runs,
+        user_simulator_provider=user_simulator_provider,
     )

     # Step 2: Post-process the results!
@@ -518,6 +524,7 @@ class AgentEvaluator:
       eval_set: EvalSet,
       eval_metrics: list[EvalMetric],
       num_runs: int,
+      user_simulator_provider: UserSimulatorProvider,
   ) -> dict[str, list[EvalCaseResult]]:
     """Returns EvalCaseResults grouped by eval case id.
@@ -541,6 +548,7 @@ class AgentEvaluator:
         eval_sets_manager=AgentEvaluator._get_eval_sets_manager(
             app_name=app_name, eval_set=eval_set
         ),
+        user_simulator_provider=user_simulator_provider,
     )

     inference_requests = [
diff --git a/src/google/adk/evaluation/conversation_scenarios.py b/src/google/adk/evaluation/conversation_scenarios.py
index 1ad7dd73..fc5d3653 100644
--- a/src/google/adk/evaluation/conversation_scenarios.py
+++ b/src/google/adk/evaluation/conversation_scenarios.py
@@ -14,22 +14,19 @@

 from __future__ import annotations

-from typing import Union
-
-from google.genai import types as genai_types
 from pydantic import Field

 from .common import EvalBaseModel


 class ConversationScenario(EvalBaseModel):
-  """Scenario for a conversation between a simulated user and the Agent."""
+  """Scenario for a conversation between a simulated user and the Agent under test."""

-  starting_prompt: Union[str, genai_types.Content]
+  starting_prompt: str
   """Starting prompt for the conversation.

-  This prompt acts as the first user message that is given to the Agent. Any
-  subsequent user messages are obtained by the system that is simulating the
+  This prompt acts as the fixed first user message that is given to the Agent.
+  Any subsequent user messages are obtained by the system that is simulating the
   user.
   """
diff --git a/src/google/adk/evaluation/eval_case.py b/src/google/adk/evaluation/eval_case.py
index 5e27aa41..9d338901 100644
--- a/src/google/adk/evaluation/eval_case.py
+++ b/src/google/adk/evaluation/eval_case.py
@@ -20,6 +20,7 @@ from typing import Union

 from google.genai import types as genai_types
 from pydantic import Field
+from pydantic import model_validator
 from typing_extensions import TypeAlias

 from .app_details import AppDetails
@@ -121,7 +122,7 @@ class SessionInput(EvalBaseModel):


 StaticConversation: TypeAlias = list[Invocation]
-"""A conversation where user's query for each invocation is already specified."""
+"""A conversation where the user's queries for each invocation are already specified."""


 class EvalCase(EvalBaseModel):
@@ -158,6 +159,15 @@ class EvalCase(EvalBaseModel):
   )
   """A list of rubrics that are applicable to all the invocations in the conversation of this eval case."""

+  @model_validator(mode="after")
+  def ensure_conversation_xor_conversation_scenario(self) -> EvalCase:
+    if (self.conversation is None) == (self.conversation_scenario is None):
+      raise ValueError(
+          "Exactly one of conversation and conversation_scenario must be"
+          " provided in an EvalCase."
+      )
+    return self
+

 def get_all_tool_calls(
     intermediate_data: Optional[IntermediateDataType],
diff --git a/src/google/adk/evaluation/eval_config.py b/src/google/adk/evaluation/eval_config.py
index e1b1fd6c..8302bb91 100644
--- a/src/google/adk/evaluation/eval_config.py
+++ b/src/google/adk/evaluation/eval_config.py
@@ -73,7 +73,7 @@ the third one uses `LlmAsAJudgeCriterion`.
   user_simulator_config: Optional[BaseUserSimulatorConfig] = Field(
       default=None,
-      description="""Config to be used by the user simulator.""",
+      description="Config to be used by the user simulator.",
   )
diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py
index 97986dfd..dff61810 100644
--- a/src/google/adk/evaluation/evaluation_generator.py
+++ b/src/google/adk/evaluation/evaluation_generator.py
@@ -14,11 +14,14 @@ from __future__ import annotations

+import copy
 import importlib
 from typing import Any
+from typing import AsyncGenerator
 from typing import Optional
 import uuid

+from google.genai.types import Content
 from pydantic import BaseModel

 from ..agents.llm_agent import Agent
@@ -41,6 +44,9 @@ from .eval_case import InvocationEvents
 from .eval_case import SessionInput
 from .eval_set import EvalSet
 from .request_intercepter_plugin import _RequestIntercepterPlugin
+from .user_simulator import Status as UserSimulatorStatus
+from .user_simulator import UserSimulator
+from .user_simulator_provider import UserSimulatorProvider

 _USER_AUTHOR = "user"
 _DEFAULT_AUTHOR = "agent"
@@ -79,11 +85,13 @@ class EvaluationGenerator:
     results = []

     for eval_case in eval_set.eval_cases:
+      # This entry point only supports static conversations, so an
+      # unconfigured provider suffices here.
+      user_simulator = UserSimulatorProvider().provide(eval_case)
       responses = []
       for _ in range(repeat_num):
         response_invocations = await EvaluationGenerator._process_query(
-            eval_case.conversation,
             agent_module_path,
+            user_simulator,
             agent_name,
             eval_case.session_input,
         )
@@ -123,8 +131,8 @@ class EvaluationGenerator:

   @staticmethod
   async def _process_query(
-      invocations: list[Invocation],
       module_name: str,
+      user_simulator: UserSimulator,
       agent_name: Optional[str] = None,
       initial_session: Optional[SessionInput] = None,
   ) -> list[Invocation]:
@@ -141,13 +149,44 @@ class EvaluationGenerator:
     assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
     return await EvaluationGenerator._generate_inferences_from_root_agent(
-        invocations, agent_to_evaluate, reset_func, initial_session
+        agent_to_evaluate,
+        user_simulator=user_simulator,
+        reset_func=reset_func,
+        initial_session=initial_session,
     )

+  @staticmethod
+  async def _generate_inferences_for_single_user_invocation(
+      runner: Runner,
+      user_id: str,
+      session_id: str,
+      user_content: Content,
+  ) -> AsyncGenerator[Event, None]:
+    invocation_id = None
+
+    async with Aclosing(
+        runner.run_async(
+            user_id=user_id,
+            session_id=session_id,
+            new_message=user_content,
+        )
+    ) as agen:
+
+      async for event in agen:
+        if not invocation_id:
+          invocation_id = event.invocation_id
+          yield Event(
+              content=user_content,
+              author=_USER_AUTHOR,
+              invocation_id=invocation_id,
+          )
+
+        yield event
+
   @staticmethod
   async def _generate_inferences_from_root_agent(
-      invocations: list[Invocation],
       root_agent: Agent,
+      user_simulator: UserSimulator,
       reset_func: Optional[Any] = None,
       initial_session: Optional[SessionInput] = None,
       session_id: Optional[str] = None,
@@ -155,7 +194,8 @@
       artifact_service: Optional[BaseArtifactService] = None,
       memory_service: Optional[BaseMemoryService] = None,
   ) -> list[Invocation]:
-    """Scrapes the root agent given the list of Invocations."""
+    """Scrapes the root agent in coordination with the user simulator."""
+
     if not session_service:
       session_service = InMemorySessionService()
@@ -194,29 +234,19 @@ class EvaluationGenerator:
         plugins=[request_intercepter_plugin],
     ) as runner:
       events = []
-
-      for invocation in invocations:
-        user_content = invocation.user_content
-        invocation_id = None
-
-        async with Aclosing(
-            runner.run_async(
-                user_id=user_id, session_id=session_id, new_message=user_content
-            )
-        ) as agen:
-
-          async for event in agen:
-            if not invocation_id:
-              invocation_id = event.invocation_id
-              events.append(
-                  Event(
-                      content=user_content,
-                      author=_USER_AUTHOR,
-                      invocation_id=invocation_id,
-                  )
-              )
-
+      while True:
+        next_user_message = await user_simulator.get_next_user_message(
+            copy.deepcopy(events)
+        )
+        if next_user_message.status == UserSimulatorStatus.SUCCESS:
+          async for (
+              event
+          ) in EvaluationGenerator._generate_inferences_for_single_user_invocation(
+              runner, user_id, session_id, next_user_message.user_message
+          ):
             events.append(event)
+        else:  # No message was generated; end the conversation.
+          break

     app_details_by_invocation_id = (
         EvaluationGenerator._get_app_details_by_invocation_id(
diff --git a/src/google/adk/evaluation/llm_backed_user_simulator.py b/src/google/adk/evaluation/llm_backed_user_simulator.py
index cda86b93..9993563e 100644
--- a/src/google/adk/evaluation/llm_backed_user_simulator.py
+++ b/src/google/adk/evaluation/llm_backed_user_simulator.py
@@ -22,7 +22,9 @@ from pydantic import Field
 from typing_extensions import override

 from ..events.event import Event
+from ..models.registry import LLMRegistry
 from ..utils.feature_decorator import experimental
+from .conversation_scenarios import ConversationScenario
 from .evaluator import Evaluator
 from .user_simulator import BaseUserSimulatorConfig
 from .user_simulator import NextUserMessage
@@ -37,52 +39,59 @@
       description="The model to use for user simulation.",
   )

-  model_config: Optional[genai_types.GenerateContentConfig] = Field(
-      default=genai_types.GenerateContentConfig,
+  model_configuration: genai_types.GenerateContentConfig = Field(
+      default_factory=genai_types.GenerateContentConfig,
       description="The configuration for the model.",
   )

   max_allowed_invocations: int = Field(
       default=20,
       description="""Maximum number of invocations allowed by the simulated
-interaction. This property allows us to stop a run-off conversation, where the
-agent and the user simulator get into an never ending loop.
+interaction. This property allows us to stop a runaway conversation, where the
+agent and the user simulator get into a never-ending loop. The initial fixed
+prompt is also counted as an invocation.

-(Not recommended)If you don't want a limit, you can set the value to -1.
-  """,
+(Not recommended) If you don't want a limit, you can set the value to -1.""",
   )


 @experimental
 class LlmBackedUserSimulator(UserSimulator):
-  """A UserSimulator that uses a LLM to generate messages on behalf of the user."""
+  """A UserSimulator that uses an LLM to generate messages on behalf of the user."""

   config_type: ClassVar[type[LlmBackedUserSimulatorConfig]] = (
       LlmBackedUserSimulatorConfig
   )

-  def __init__(self, *, config: BaseUserSimulatorConfig):
+  def __init__(
+      self,
+      *,
+      config: BaseUserSimulatorConfig,
+      conversation_scenario: ConversationScenario,
+  ):
     super().__init__(config, config_type=LlmBackedUserSimulator.config_type)
+    self._conversation_scenario = conversation_scenario

   @override
   async def get_next_user_message(
       self,
-      conversation_plan: str,
       events: list[Event],
   ) -> NextUserMessage:
     """Returns the next user message to send to the agent with help from a LLM.

     Args:
-      conversation_plan: A plan that user simulation system needs to follow as
-        it plays out the conversation.
       events: The unaltered conversation history between the user and the
         agent(s) under evaluation.
+
+    Returns:
+      A NextUserMessage object containing the next user message to send to the
+      agent, or a status indicating why no message was generated.
     """
     raise NotImplementedError()

   @override
   def get_simulation_evaluator(
       self,
-  ) -> Evaluator:
+  ) -> Optional[Evaluator]:
     """Returns an Evaluator that evaluates if the simulation was successful or not."""
     raise NotImplementedError()
diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py
index 84e26cb1..4dfd391a 100644
--- a/src/google/adk/evaluation/local_eval_service.py
+++ b/src/google/adk/evaluation/local_eval_service.py
@@ -22,6 +22,8 @@ from typing import Callable
 from typing import Optional
 import uuid

+from google.genai.types import Content
+from google.genai.types import Part
 from typing_extensions import override

 from ..agents.base_agent import BaseAgent
@@ -51,6 +53,7 @@ from .evaluator import EvalStatus
 from .evaluator import EvaluationResult
 from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
 from .metric_evaluator_registry import MetricEvaluatorRegistry
+from .user_simulator_provider import UserSimulatorProvider

 logger = logging.getLogger('google_adk.' + __name__)
@@ -74,6 +77,7 @@ class LocalEvalService(BaseEvalService):
       artifact_service: Optional[BaseArtifactService] = None,
       eval_set_results_manager: Optional[EvalSetResultsManager] = None,
       session_id_supplier: Callable[[], str] = _get_session_id,
+      user_simulator_provider: UserSimulatorProvider = UserSimulatorProvider(),
   ):
     self._root_agent = root_agent
     self._eval_sets_manager = eval_sets_manager
@@ -87,6 +91,7 @@ class LocalEvalService(BaseEvalService):
     self._artifact_service = artifact_service
     self._eval_set_results_manager = eval_set_results_manager
     self._session_id_supplier = session_id_supplier
+    self._user_simulator_provider = user_simulator_provider

   @override
   async def perform_inference(
@@ -211,6 +216,48 @@ class LocalEvalService(BaseEvalService):
     # would be the score for the eval case.
     overall_eval_metric_results = []

+    user_id = (
+        eval_case.session_input.user_id
+        if eval_case.session_input and eval_case.session_input.user_id
+        else 'test_user_id'
+    )
+
+    if eval_case.conversation_scenario:
+      logger.warning(
+          'Skipping evaluation of variable-length conversation scenario in eval'
+          ' set/case %s/%s.',
+          inference_result.eval_set_id,
+          inference_result.eval_case_id,
+      )
+      for actual_invocation in inference_result.inferences:
+        eval_metric_result_per_invocation.append(
+            EvalMetricResultPerInvocation(
+                actual_invocation=actual_invocation,
+                expected_invocation=Invocation(
+                    user_content=actual_invocation.user_content,
+                    final_response=Content(
+                        parts=[Part(text='N/A')], role='model'
+                    ),
+                ),
+            )
+        )
+      eval_case_result = EvalCaseResult(
+          eval_set_file=inference_result.eval_set_id,
+          eval_set_id=inference_result.eval_set_id,
+          eval_id=inference_result.eval_case_id,
+          final_eval_status=EvalStatus.NOT_EVALUATED,
+          overall_eval_metric_results=overall_eval_metric_results,
+          eval_metric_result_per_invocation=eval_metric_result_per_invocation,
+          session_id=inference_result.session_id,
+          session_details=await self._session_service.get_session(
+              app_name=inference_result.app_name,
+              user_id=user_id,
+              session_id=inference_result.session_id,
+          ),
+          user_id=user_id,
+      )
+      return (inference_result, eval_case_result)
+
     if len(inference_result.inferences) != len(eval_case.conversation):
       raise ValueError(
           'Inferences should match conversations in eval case. Found'
@@ -281,11 +328,6 @@ class LocalEvalService(BaseEvalService):
     final_eval_status = self._generate_final_eval_status(
         overall_eval_metric_results
     )
-    user_id = (
-        eval_case.session_input.user_id
-        if eval_case.session_input and eval_case.session_input.user_id
-        else 'test_user_id'
-    )

     eval_case_result = EvalCaseResult(
         eval_set_file=inference_result.eval_set_id,
@@ -373,8 +415,8 @@ class LocalEvalService(BaseEvalService):
     try:
       inferences = (
           await EvaluationGenerator._generate_inferences_from_root_agent(
-              invocations=eval_case.conversation,
               root_agent=root_agent,
+              user_simulator=self._user_simulator_provider.provide(eval_case),
              initial_session=initial_session,
              session_id=session_id,
              session_service=self._session_service,
diff --git a/src/google/adk/evaluation/static_user_simulator.py b/src/google/adk/evaluation/static_user_simulator.py
new file mode 100644
index 00000000..4c5e2cb5
--- /dev/null
+++ b/src/google/adk/evaluation/static_user_simulator.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from typing_extensions import override
+
+from ..events.event import Event
+from ..utils.feature_decorator import experimental
+from .eval_case import StaticConversation
+from .evaluator import Evaluator
+from .user_simulator import BaseUserSimulatorConfig
+from .user_simulator import NextUserMessage
+from .user_simulator import Status
+from .user_simulator import UserSimulator
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+@experimental
+class StaticUserSimulator(UserSimulator):
+  """A UserSimulator that replays a static list of user messages."""
+
+  def __init__(self, *, static_conversation: StaticConversation):
+    super().__init__(
+        BaseUserSimulatorConfig(), config_type=BaseUserSimulatorConfig
+    )
+    self.static_conversation = static_conversation
+    self.invocation_idx = 0
+
+  @override
+  async def get_next_user_message(
+      self,
+      events: list[Event],
+  ) -> NextUserMessage:
+    """Returns the next user message to send to the agent from a static list.
+
+    Args:
+      events: The unaltered conversation history between the user and the
+        agent(s) under evaluation.
+
+    Returns:
+      A NextUserMessage object containing the next user message to send to the
+      agent, or a status indicating why no message was generated.
+    """
+    # Check whether we have reached the end of the static conversation.
+    if self.invocation_idx >= len(self.static_conversation):
+      return NextUserMessage(status=Status.STOP_SIGNAL_DETECTED)
+
+    # Return the next message from the static list.
+    next_user_content = self.static_conversation[
+        self.invocation_idx
+    ].user_content
+    self.invocation_idx += 1
+    return NextUserMessage(
+        status=Status.SUCCESS,
+        user_message=next_user_content,
+    )
+
+  @override
+  def get_simulation_evaluator(
+      self,
+  ) -> Optional[Evaluator]:
+    """The StaticUserSimulator does not require an evaluator."""
+    return None
diff --git a/src/google/adk/evaluation/user_simulator.py b/src/google/adk/evaluation/user_simulator.py
index 39297805..c5ab013d 100644
--- a/src/google/adk/evaluation/user_simulator.py
+++ b/src/google/adk/evaluation/user_simulator.py
@@ -23,6 +23,7 @@ from pydantic import alias_generators
 from pydantic import BaseModel
 from pydantic import ConfigDict
 from pydantic import Field
+from pydantic import model_validator
 from pydantic import ValidationError

 from ..events.event import Event
@@ -32,7 +33,7 @@ from .evaluator import Evaluator


 class BaseUserSimulatorConfig(BaseModel):
-  """Base class for configurations pertaining to User Simulator."""
+  """Base class for configurations pertaining to a user simulator."""

   model_config = ConfigDict(
       alias_generator=alias_generators.to_camel,
@@ -60,13 +61,25 @@ not able to do so."""
   )

   user_message: Optional[genai_types.Content] = Field(
-      description="""The next user message."""
+      description="""The next user message.""", default=None
   )

+  @model_validator(mode="after")
+  def ensure_user_message_iff_success(self) -> NextUserMessage:
+    if (self.status == Status.SUCCESS) == (self.user_message is None):
+      raise ValueError(
+          "A user_message should be provided if and only if the status is"
+          " SUCCESS"
+      )
+    return self
+

 @experimental
 class UserSimulator(ABC):
-  """A user simulator for the purposes of automating interaction with an Agent."""
+  """A user simulator for the purposes of automating interaction with an Agent.
+
+  Typically, one user simulator instance is created per eval case.
+  """

   def __init__(
       self,
@@ -82,21 +95,22 @@ class UserSimulator(ABC):

   async def get_next_user_message(
       self,
-      conversation_plan: str,
       events: list[Event],
   ) -> NextUserMessage:
     """Returns the next user message to send to the agent.

     Args:
-      conversation_plan: A plan that user simulation system needs to follow as
-        it plays out the conversation.
       events: The unaltered conversation history between the user and the
         agent(s) under evaluation.
+
+    Returns:
+      A NextUserMessage object containing the next user message to send to the
+      agent, or a status indicating why no message was generated.
     """
     raise NotImplementedError()

   def get_simulation_evaluator(
       self,
-  ) -> Evaluator:
-    """Returns an instnace of an Evaluator that evaluates if the simulation was successful or not."""
+  ) -> Optional[Evaluator]:
+    """Returns an instance of an Evaluator that evaluates if the user simulation was successful or not."""
     raise NotImplementedError()
diff --git a/src/google/adk/evaluation/user_simulator_provider.py b/src/google/adk/evaluation/user_simulator_provider.py
new file mode 100644
index 00000000..1aea8c8c
--- /dev/null
+++ b/src/google/adk/evaluation/user_simulator_provider.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+from ..utils.feature_decorator import experimental
+from .eval_case import EvalCase
+from .llm_backed_user_simulator import LlmBackedUserSimulator
+from .static_user_simulator import StaticUserSimulator
+from .user_simulator import BaseUserSimulatorConfig
+from .user_simulator import UserSimulator
+
+
+@experimental
+class UserSimulatorProvider:
+  """Provides a UserSimulator instance per EvalCase.
+
+  Mixes configuration data from the EvalConfig with per-EvalCase
+  conversation data.
+  """
+
+  def __init__(
+      self,
+      user_simulator_config: Optional[BaseUserSimulatorConfig] = None,
+  ):
+    if user_simulator_config is None:
+      user_simulator_config = BaseUserSimulatorConfig()
+    elif not isinstance(user_simulator_config, BaseUserSimulatorConfig):
+      raise ValueError(f"Expected config of type `{BaseUserSimulatorConfig}`.")
+    # Only the base type is checked here; the concrete user simulator is
+    # expected to fully validate the config it receives.
+    self._user_simulator_config = user_simulator_config
+
+  def provide(self, eval_case: EvalCase) -> UserSimulator:
+    """Provides a user simulator appropriate for the given EvalCase.
+
+    Args:
+      eval_case: An EvalCase containing a `conversation` xor a
+        `conversation_scenario`.
+
+    Returns:
+      A StaticUserSimulator or an LlmBackedUserSimulator based on the type of
+      the conversation data.
+
+    Raises:
+      ValueError: If no conversation data or multiple types of conversation
+        data are provided.
+    """
+    if eval_case.conversation is None:
+      if eval_case.conversation_scenario is None:
+        raise ValueError(
+            "Neither static invocations nor conversation scenario provided in"
+            " EvalCase. Provide exactly one."
+        )
+
+      return LlmBackedUserSimulator(
+          config=self._user_simulator_config,
+          conversation_scenario=eval_case.conversation_scenario,
+      )
+
+    else:  # eval_case.conversation is not None
+      if eval_case.conversation_scenario is not None:
+        raise ValueError(
+            "Both static invocations and conversation scenario provided in"
+            " EvalCase. Provide exactly one."
+        )
+
+      return StaticUserSimulator(static_conversation=eval_case.conversation)
diff --git a/tests/unittests/evaluation/test_eval_case.py b/tests/unittests/evaluation/test_eval_case.py
index bea81d46..4784a9a0 100644
--- a/tests/unittests/evaluation/test_eval_case.py
+++ b/tests/unittests/evaluation/test_eval_case.py
@@ -14,6 +14,8 @@

 from __future__ import annotations

+from google.adk.evaluation.conversation_scenarios import ConversationScenario
+from google.adk.evaluation.eval_case import EvalCase
 from google.adk.evaluation.eval_case import get_all_tool_calls
 from google.adk.evaluation.eval_case import get_all_tool_calls_with_responses
 from google.adk.evaluation.eval_case import get_all_tool_responses
@@ -246,3 +248,36 @@ def test_get_all_tool_calls_with_responses_with_invocation_events():
       (tool_call1, tool_response1),
       (tool_call2, None),
   ]
+
+
+def test_conversation_and_conversation_scenario_mutual_exclusion():
+  """Tests the ensure_conversation_xor_conversation_scenario validator."""
+  test_conversation_scenario = ConversationScenario(
+      starting_prompt='', conversation_plan=''
+  )
+
+  with pytest.raises(
+      ValueError,
+      match=(
+          'Exactly one of conversation and conversation_scenario must be'
+          ' provided in an EvalCase.'
+      ),
+  ):
+    EvalCase(eval_id='test_id')
+
+  with pytest.raises(
+      ValueError,
+      match=(
+          'Exactly one of conversation and conversation_scenario must be'
+          ' provided in an EvalCase.'
+ ), + ): + EvalCase( + eval_id='test_id', + conversation=[], + conversation_scenario=test_conversation_scenario, + ) + + # these two should not cause exceptions + EvalCase(eval_id='test_id', conversation=[]) + EvalCase(eval_id='test_id', conversation_scenario=test_conversation_scenario) diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 133c6187..27372f12 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -18,9 +18,13 @@ from google.adk.evaluation.app_details import AgentDetails from google.adk.evaluation.app_details import AppDetails from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin +from google.adk.evaluation.user_simulator import NextUserMessage +from google.adk.evaluation.user_simulator import Status as UserSimulatorStatus +from google.adk.evaluation.user_simulator import UserSimulator from google.adk.events.event import Event from google.adk.models.llm_request import LlmRequest from google.genai import types +import pytest def _build_event( @@ -324,3 +328,130 @@ class TestGetAppDetailsByInvocationId: } assert app_details == expected_app_details assert mock_request_intercepter.get_model_request.call_count == 3 + + +class TestGenerateInferencesForSingleUserInvocation: + """Test cases for EvaluationGenerator._generate_inferences_for_single_user_invocation method.""" + + @pytest.mark.asyncio + async def test_generate_inferences_with_mock_runner(self, mocker): + """Tests inference generation with a mocked runner.""" + runner = mocker.MagicMock() + + agent_parts = [types.Part(text="Agent response")] + + async def mock_run_async(*args, **kwargs): + yield _build_event( + author="agent", + parts=agent_parts, + invocation_id="inv1", + ) + + runner.run_async.return_value = mock_run_async() + + user_content = types.Content(parts=[types.Part(text="User query")]) + + events = [ + event + async for event in EvaluationGenerator._generate_inferences_for_single_user_invocation( + runner, "test_user", "test_session", user_content + ) + ] + + assert len(events) == 2 + assert events[0].author == "user" + assert events[0].content == user_content + assert events[0].invocation_id == "inv1" + assert events[1].author == "agent" + assert events[1].content.parts == agent_parts + + runner.run_async.assert_called_once_with( + user_id="test_user", + session_id="test_session", + new_message=user_content, + ) + + +@pytest.fixture +def mock_runner(mocker): + """Provides a mock Runner for testing.""" + mock_runner_cls = mocker.patch( + "google.adk.evaluation.evaluation_generator.Runner" + ) + mock_runner_instance = mocker.AsyncMock() + mock_runner_instance.__aenter__.return_value = mock_runner_instance + mock_runner_cls.return_value = mock_runner_instance + yield mock_runner_instance + + +@pytest.fixture +def mock_session_service(mocker): + """Provides a mock InMemorySessionService for testing.""" + mock_session_service_cls = mocker.patch( + "google.adk.evaluation.evaluation_generator.InMemorySessionService" + ) + mock_session_service_instance = mocker.MagicMock() + mock_session_service_instance.create_session = mocker.AsyncMock() + mock_session_service_cls.return_value = mock_session_service_instance + yield mock_session_service_instance + + +class TestGenerateInferencesFromRootAgent: + """Test cases for 
+
+  @pytest.mark.asyncio
+  async def test_generates_inferences_with_user_simulator(
+      self, mocker, mock_runner, mock_session_service
+  ):
+    """Tests that inferences are generated by interacting with a user simulator."""
+    mock_agent = mocker.MagicMock()
+    mock_user_sim = mocker.MagicMock(spec=UserSimulator)
+
+    # The mock user simulator will produce one message, then stop.
+    async def get_next_user_message_side_effect(*args, **kwargs):
+      if mock_user_sim.get_next_user_message.call_count == 1:
+        return NextUserMessage(
+            status=UserSimulatorStatus.SUCCESS,
+            user_message=types.Content(parts=[types.Part(text="message 1")]),
+        )
+      return NextUserMessage(status=UserSimulatorStatus.STOP_SIGNAL_DETECTED)
+
+    mock_user_sim.get_next_user_message = mocker.AsyncMock(
+        side_effect=get_next_user_message_side_effect
+    )
+
+    mock_generate_inferences = mocker.patch(
+        "google.adk.evaluation.evaluation_generator.EvaluationGenerator._generate_inferences_for_single_user_invocation"
+    )
+    mocker.patch(
+        "google.adk.evaluation.evaluation_generator.EvaluationGenerator._get_app_details_by_invocation_id"
+    )
+    mocker.patch(
+        "google.adk.evaluation.evaluation_generator.EvaluationGenerator.convert_events_to_eval_invocations"
+    )
+
+    # Each call to _generate_inferences_for_single_user_invocation will
+    # yield one user and one agent event.
+    async def mock_generate_inferences_side_effect(
+        runner, user_id, session_id, user_content
+    ):
+      yield _build_event("user", user_content.parts, "inv1")
+      yield _build_event("agent", [types.Part(text="agent_response")], "inv1")
+
+    mock_generate_inferences.side_effect = mock_generate_inferences_side_effect
+
+    await EvaluationGenerator._generate_inferences_from_root_agent(
+        root_agent=mock_agent,
+        user_simulator=mock_user_sim,
+    )
+
+    # Check that the user simulator was called until it stopped.
+    assert mock_user_sim.get_next_user_message.call_count == 2
+
+    # Check that we generated inferences for each user message.
+    assert mock_generate_inferences.call_count == 1
+
+    # Check the content of the user message passed to inference generation.
+    mock_generate_inferences.assert_called_once()
+    called_with_content = mock_generate_inferences.call_args.args[3]
+    assert called_with_content.parts[0].text == "message 1"
diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py
index c9010444..d5faf55f 100644
--- a/tests/unittests/evaluation/test_local_eval_service.py
+++ b/tests/unittests/evaluation/test_local_eval_service.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import asyncio
 import sys
@@ -246,6 +248,7 @@ async def test_evaluate_success(
   mock_eval_case = mocker.MagicMock(spec=EvalCase)
   mock_eval_case.conversation = []
+  mock_eval_case.conversation_scenario = None
   mock_eval_case.session_input = None
   mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
@@ -321,6 +324,7 @@ async def test_evaluate_single_inference_result(
       invocation.model_copy(deep=True),
       invocation.model_copy(deep=True),
   ]
+  mock_eval_case.conversation_scenario = None
   mock_eval_case.session_input = None
   mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
@@ -352,6 +356,51 @@ async def test_evaluate_single_inference_result(
   assert metric_result.eval_status == EvalStatus.PASSED


+@pytest.mark.asyncio
+async def test_evaluate_single_inference_result_skipped_for_conversation_scenario(
+    eval_service, mock_eval_sets_manager, mocker
+):
+  """To be removed once evaluation is implemented for conversation scenarios."""
+  invocation = Invocation(
+      user_content=genai_types.Content(
+          parts=[genai_types.Part(text="test user content.")]
+      ),
+      final_response=genai_types.Content(
+          parts=[genai_types.Part(text="test final response.")]
+      ),
+  )
+  inference_result = InferenceResult(
+      app_name="test_app",
+      eval_set_id="test_eval_set",
+      eval_case_id="case1",
+      inferences=[invocation.model_copy(deep=True)],
+      session_id="session1",
+  )
+  eval_metric = EvalMetric(metric_name="fake_metric", threshold=0.5)
+  evaluate_config = EvaluateConfig(eval_metrics=[eval_metric], parallelism=1)
+
+  mock_eval_case = mocker.MagicMock(spec=EvalCase)
+  mock_eval_case.conversation = None
+  mock_eval_case.conversation_scenario = mocker.MagicMock()
+  mock_eval_case.session_input = None
+  mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case
+
+  _, result = await eval_service._evaluate_single_inference_result(
+      inference_result=inference_result, evaluate_config=evaluate_config
+  )
+  assert isinstance(result, EvalCaseResult)
+  assert result.eval_id == "case1"
+  assert result.final_eval_status == EvalStatus.NOT_EVALUATED
+  assert not result.overall_eval_metric_results
+  assert len(result.eval_metric_result_per_invocation) == 1
+  invocation_result = result.eval_metric_result_per_invocation[0]
+  assert not invocation_result.eval_metric_results
+  assert (
+      invocation_result.expected_invocation.final_response.parts[0].text
+      == "N/A"
+  )
+
+
 def test_generate_final_eval_status_doesn_t_throw_on(eval_service):
   # How to fix if this test case fails?
   # This test case has failed mainly because a new EvalStatus got added. You
diff --git a/tests/unittests/evaluation/test_static_user_simulator.py b/tests/unittests/evaluation/test_static_user_simulator.py
new file mode 100644
index 00000000..5cc70c80
--- /dev/null
+++ b/tests/unittests/evaluation/test_static_user_simulator.py
@@ -0,0 +1,54 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.evaluation import static_user_simulator
+from google.adk.evaluation import user_simulator
+from google.adk.evaluation.eval_case import Invocation
+from google.genai import types
+import pytest
+
+
+class TestStaticUserSimulator:
+  """Test cases for StaticUserSimulator."""
+
+  @pytest.mark.asyncio
+  async def test_get_next_user_message(self):
+    """Tests that the provided messages are returned in order, followed by the stop signal."""
+    conversation = [
+        Invocation(
+            invocation_id="inv1",
+            user_content=types.Content(parts=[types.Part(text="message 1")]),
+        ),
+        Invocation(
+            invocation_id="inv2",
+            user_content=types.Content(parts=[types.Part(text="message 2")]),
+        ),
+    ]
+    simulator = static_user_simulator.StaticUserSimulator(
+        static_conversation=conversation
+    )
+
+    next_message_1 = await simulator.get_next_user_message(events=[])
+    assert user_simulator.Status.SUCCESS == next_message_1.status
+    assert "message 1" == next_message_1.user_message.parts[0].text
+
+    next_message_2 = await simulator.get_next_user_message(events=[])
+    assert user_simulator.Status.SUCCESS == next_message_2.status
+    assert "message 2" == next_message_2.user_message.parts[0].text
+
+    next_message_3 = await simulator.get_next_user_message(events=[])
+    assert user_simulator.Status.STOP_SIGNAL_DETECTED == next_message_3.status
+    assert next_message_3.user_message is None
diff --git a/tests/unittests/evaluation/test_user_simulator.py b/tests/unittests/evaluation/test_user_simulator.py
new file mode 100644
index 00000000..c3e1e606
--- /dev/null
+++ b/tests/unittests/evaluation/test_user_simulator.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.evaluation.user_simulator import NextUserMessage
+from google.adk.evaluation.user_simulator import Status
+from google.genai.types import Content
+import pytest
+
+
+def test_next_user_message_validation():
+  """Tests post-init validation of NextUserMessage."""
+  with pytest.raises(
+      ValueError,
+      match=(
+          "A user_message should be provided if and only if the status is"
+          " SUCCESS"
+      ),
+  ):
+    NextUserMessage(status=Status.SUCCESS)
+
+  with pytest.raises(
+      ValueError,
+      match=(
+          "A user_message should be provided if and only if the status is"
+          " SUCCESS"
+      ),
+  ):
+    NextUserMessage(status=Status.TURN_LIMIT_REACHED, user_message=Content())
+
+  # These two should not raise.
+  NextUserMessage(status=Status.SUCCESS, user_message=Content())
+  NextUserMessage(status=Status.TURN_LIMIT_REACHED)
diff --git a/tests/unittests/evaluation/test_user_simulator_provider.py b/tests/unittests/evaluation/test_user_simulator_provider.py
new file mode 100644
index 00000000..7cff4241
--- /dev/null
+++ b/tests/unittests/evaluation/test_user_simulator_provider.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.evaluation import conversation_scenarios
+from google.adk.evaluation import eval_case
+from google.adk.evaluation import user_simulator_provider
+from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator
+from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig
+from google.adk.evaluation.static_user_simulator import StaticUserSimulator
+from google.genai import types
+import pytest
+
+_TEST_CONVERSATION = [
+    eval_case.Invocation(
+        invocation_id='inv1',
+        user_content=types.Content(parts=[types.Part(text='Hello!')]),
+    ),
+]
+
+_TEST_CONVERSATION_SCENARIO = conversation_scenarios.ConversationScenario(
+    starting_prompt='Hello!', conversation_plan='test plan'
+)
+
+
+class TestUserSimulatorProvider:
+  """Test cases for the UserSimulatorProvider."""
+
+  def test_provide_static_user_simulator(self):
+    """Tests the case when a StaticUserSimulator should be provided."""
+    provider = user_simulator_provider.UserSimulatorProvider()
+    test_eval_case = eval_case.EvalCase(
+        eval_id='test_eval_id',
+        conversation=_TEST_CONVERSATION,
+    )
+    simulator = provider.provide(test_eval_case)
+    assert isinstance(simulator, StaticUserSimulator)
+    assert simulator.static_conversation == _TEST_CONVERSATION
+
+  def test_provide_llm_backed_user_simulator(self, mocker):
+    """Tests the case when an LlmBackedUserSimulator should be provided."""
+    mock_llm_registry = mocker.patch(
+        'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry',
+        autospec=True,
+    )
+    mock_llm_registry.return_value.resolve.return_value = mocker.Mock()
+    # Test case 1: No config in provider.
+    provider = user_simulator_provider.UserSimulatorProvider()
+    test_eval_case = eval_case.EvalCase(
+        eval_id='test_eval_id',
+        conversation_scenario=_TEST_CONVERSATION_SCENARIO,
+    )
+    simulator = provider.provide(test_eval_case)
+    assert isinstance(simulator, LlmBackedUserSimulator)
+    assert simulator._conversation_scenario == _TEST_CONVERSATION_SCENARIO
+
+    # Test case 2: Config in provider.
+    llm_config = LlmBackedUserSimulatorConfig(
+        model='test_model',
+    )
+    provider = user_simulator_provider.UserSimulatorProvider(
+        user_simulator_config=llm_config
+    )
+    simulator = provider.provide(test_eval_case)
+    assert isinstance(simulator, LlmBackedUserSimulator)
+    assert simulator._conversation_scenario == _TEST_CONVERSATION_SCENARIO
+    assert simulator._config.model == 'test_model'
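Reviewer note: the sketch below is not part of the diff; it is a minimal usage example of how the pieces above compose once this change lands. Module paths, class names, and field names are taken from the diff itself; the eval ids and message text are illustrative only. The provider dispatches on which of the two mutually exclusive EvalCase fields is set, and the StaticUserSimulator replays the recorded messages before signaling a stop.

# Usage sketch (assumes this PR is applied; not part of the change itself).
import asyncio

from google.adk.evaluation.conversation_scenarios import ConversationScenario
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.user_simulator import Status
from google.adk.evaluation.user_simulator_provider import UserSimulatorProvider
from google.genai import types


async def main() -> None:
  provider = UserSimulatorProvider()

  # An EvalCase with a static `conversation` yields a StaticUserSimulator,
  # which replays the recorded user messages in order and then reports a
  # stop signal.
  static_case = EvalCase(
      eval_id="static_case",  # Illustrative id.
      conversation=[
          Invocation(
              invocation_id="inv1",
              user_content=types.Content(parts=[types.Part(text="Hi!")]),
          )
      ],
  )
  simulator = provider.provide(static_case)
  first = await simulator.get_next_user_message(events=[])
  assert first.status == Status.SUCCESS
  second = await simulator.get_next_user_message(events=[])
  assert second.status == Status.STOP_SIGNAL_DETECTED

  # An EvalCase with a `conversation_scenario` instead yields an
  # LlmBackedUserSimulator (whose message generation is still a stub in this
  # PR, so only the returned type is inspected here).
  scenario_case = EvalCase(
      eval_id="scenario_case",  # Illustrative id.
      conversation_scenario=ConversationScenario(
          starting_prompt="Hello!", conversation_plan="plan goes here"
      ),
  )
  print(type(provider.provide(scenario_case)).__name__)


if __name__ == "__main__":
  asyncio.run(main())

Setting both fields (or neither) on an EvalCase now fails fast in the pydantic validator added to eval_case.py, so the provider's own ValueError branches act as a second line of defense for callers that bypass model validation.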