feat: Add interface between optimization infra and LocalEvalService

* Enables the use of ADK evaluations via LocalEvalService for optimizing agents.
* Provides flexibility in choosing eval sets and eval cases for training and validation.
* Converts ADK eval results into a compact format useful for whitebox agent optimization.

Co-authored-by: Keyur Joshi <keyurj@google.com>
PiperOrigin-RevId: 875818012
Committed by: Copybara-Service
Parent: 65d9a726c5
Commit: 7b7ddda46c
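For orientation, here is a minimal usage sketch of the new sampler (not part of the diff; `eval_sets_manager`, `my_agent`, the app name, and the eval set names are illustrative assumptions):

import asyncio

from google.adk.evaluation.eval_config import EvalConfig
from google.adk.optimization.local_eval_sampler import LocalEvalSampler
from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig
from google.adk.optimization.sampler import Sampler

# Assumed to already exist: `eval_sets_manager` (an EvalSetsManager holding the
# eval sets named below) and `my_agent` (the candidate Agent to evaluate).
config = LocalEvalSamplerConfig(
    eval_config=EvalConfig(),
    app_name="my_app",
    train_eval_set="train_set",
    validation_eval_set="validation_set",  # optional; defaults to train_eval_set
)
sampler = LocalEvalSampler(config, eval_sets_manager)

# Score the candidate on the training cases and capture outputs, trajectories,
# and tool calls for whitebox optimization.
result = asyncio.run(
    sampler.sample_and_score(
        my_agent,
        example_set=Sampler.TRAIN_SET,
        capture_full_eval_data=True,
    )
)
print(result.scores)  # e.g. {"case_1": 1.0, "case_2": 0.0}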
google/adk/optimization/local_eval_sampler.py (new file)
@@ -0,0 +1,367 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import Any
from typing import Literal
from typing import Optional

from pydantic import BaseModel
from pydantic import Field

from ..agents.llm_agent import Agent
from ..evaluation.base_eval_service import EvaluateConfig
from ..evaluation.base_eval_service import EvaluateRequest
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.base_eval_service import InferenceResult
from ..evaluation.eval_case import get_all_tool_calls_with_responses
from ..evaluation.eval_case import IntermediateData
from ..evaluation.eval_case import Invocation
from ..evaluation.eval_case import InvocationEvents
from ..evaluation.eval_config import EvalConfig
from ..evaluation.eval_config import get_eval_metrics_from_config
from ..evaluation.eval_metrics import EvalStatus
from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.eval_sets_manager import EvalSetsManager
from ..evaluation.local_eval_service import LocalEvalService
from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider
from ..utils.context_utils import Aclosing
from .data_types import UnstructuredSamplingResult
from .sampler import Sampler

logger = logging.getLogger("google_adk." + __name__)


def _log_eval_summary(eval_results: list[EvalCaseResult]):
  """Logs a summary of eval results."""
  num_pass, num_fail, num_other = 0, 0, 0
  for eval_result in eval_results:
    eval_result: EvalCaseResult
    if eval_result.final_eval_status == EvalStatus.PASSED:
      num_pass += 1
    elif eval_result.final_eval_status == EvalStatus.FAILED:
      num_fail += 1
    else:
      num_other += 1
  log_str = f"Evaluation summary: {num_pass} PASSED, {num_fail} FAILED"
  if num_other:
    log_str += f", {num_other} OTHER"
  logger.info(log_str)


def extract_tool_call_data(
    intermediate_data: IntermediateData | InvocationEvents,
) -> list[dict[str, Any]]:
  """Extracts tool calls and their responses from intermediate data."""
  call_response_pairs = get_all_tool_calls_with_responses(intermediate_data)
  result = []
  for tool_call, tool_response in call_response_pairs:
    result.append({
        "name": tool_call.name,
        "args": tool_call.args,
        "response": tool_response.response if tool_response else None,
    })
  return result


def extract_single_invocation_info(
    invocation: Invocation,
) -> dict[str, Any]:
  """Extracts useful information from a single invocation."""
  user_prompt = ""
  for part in invocation.user_content.parts:
    if part.text and not part.thought:
      user_prompt += part.text
  agent_response = ""
  if invocation.final_response:
    for part in invocation.final_response.parts:
      if part.text and not part.thought:
        agent_response += part.text
  result = {"user_prompt": user_prompt, "agent_response": agent_response}
  if invocation.intermediate_data:
    tool_call_data = extract_tool_call_data(invocation.intermediate_data)
    result["tool_calls"] = tool_call_data
  return result

class LocalEvalSamplerConfig(BaseModel):
  """Contains configuration options required by the LocalEvalSampler."""

  eval_config: EvalConfig = Field(
      required=True,
      description="The configuration for the evaluation.",
  )

  app_name: str = Field(
      required=True,
      description="The app name to use for evaluation.",
  )

  train_eval_set: str = Field(
      required=True,
      description="The name of the eval set to use for optimization.",
  )

  train_eval_case_ids: Optional[list[str]] = Field(
      default=None,
      description=(
          "The ids of the eval cases to use for optimization. If not provided,"
          " all eval cases in the train_eval_set will be used."
      ),
  )

  validation_eval_set: Optional[str] = Field(
      default=None,
      description=(
          "The name of the eval set to use for validating the optimized agent."
          " If not provided, the train_eval_set will also be used for"
          " validation."
      ),
  )

  validation_eval_case_ids: Optional[list[str]] = Field(
      default=None,
      description=(
          "The ids of the eval cases to use for validating the optimized agent."
          " If not provided, all eval cases in the validation_eval_set will be"
          " used. If validation_eval_set is also not provided, all train eval"
          " cases will be used."
      ),
  )

class LocalEvalSampler(Sampler[UnstructuredSamplingResult]):
  """Evaluates candidate agents with the ADK's LocalEvalService."""

  def __init__(
      self,
      config: LocalEvalSamplerConfig,
      eval_sets_manager: EvalSetsManager,
  ):
    self._config = config
    self._eval_sets_manager = eval_sets_manager

    self._train_eval_set = self._config.train_eval_set
    self._train_eval_case_ids = (
        self._config.train_eval_case_ids
        or self._get_eval_case_ids(self._train_eval_set)
    )

    self._validation_eval_set = (
        self._config.validation_eval_set or self._train_eval_set
    )
    if self._config.validation_eval_case_ids:
      self._validation_eval_case_ids = self._config.validation_eval_case_ids
    elif self._config.validation_eval_set:
      self._validation_eval_case_ids = self._get_eval_case_ids(
          self._validation_eval_set
      )
    else:
      self._validation_eval_case_ids = self._train_eval_case_ids

  def _get_selected_example_set_id(
      self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET]
  ) -> str:
    """Returns the ID of the selected example set."""
    return {
        Sampler.TRAIN_SET: self._train_eval_set,
        Sampler.VALIDATION_SET: self._validation_eval_set,
    }[example_set]

  def _get_all_example_ids(
      self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET]
  ) -> list[str]:
    """Returns the IDs of all examples in the selected example set."""
    return {
        Sampler.TRAIN_SET: self._train_eval_case_ids,
        Sampler.VALIDATION_SET: self._validation_eval_case_ids,
    }[example_set]

  def _get_eval_case_ids(self, eval_set_id: str) -> list[str]:
    """Returns the ids of eval cases in the given eval set."""
    eval_set = self._eval_sets_manager.get_eval_set(
        app_name=self._config.app_name,
        eval_set_id=eval_set_id,
    )
    if eval_set:
      return [eval_case.eval_id for eval_case in eval_set.eval_cases]
    else:
      raise ValueError(
          f"Eval set `{eval_set_id}` does not exist for app"
          f" `{self._config.app_name}`."
      )

  async def _evaluate_agent(
      self,
      agent: Agent,
      eval_set_id: str,
      eval_case_ids: list[str],
  ) -> list[EvalCaseResult]:
    """Evaluates the agent on the requested eval cases and returns the results.

    Args:
      agent: The agent to evaluate.
      eval_set_id: The id of the eval set to use for evaluation.
      eval_case_ids: The ids of the eval cases to use for evaluation.

    Returns:
      A list of EvalCaseResult, one per eval case.
    """
    # create the inference request
    inference_request = InferenceRequest(
        app_name=self._config.app_name,
        eval_set_id=eval_set_id,
        eval_case_ids=eval_case_ids,
        inference_config=InferenceConfig(),
    )

    # create the LocalEvalService
    user_simulator_provider = UserSimulatorProvider(
        self._config.eval_config.user_simulator_config
    )
    eval_service = LocalEvalService(
        root_agent=agent,
        eval_sets_manager=self._eval_sets_manager,
        user_simulator_provider=user_simulator_provider,
    )

    # inference/sampling
    async with Aclosing(
        eval_service.perform_inference(inference_request=inference_request)
    ) as agen:
      inference_results: list[InferenceResult] = [
          inference_result async for inference_result in agen
      ]

    # evaluation
    eval_metrics = get_eval_metrics_from_config(self._config.eval_config)
    evaluate_request = EvaluateRequest(
        inference_results=inference_results,
        evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
    )
    async with Aclosing(
        eval_service.evaluate(evaluate_request=evaluate_request)
    ) as agen:
      eval_results: list[EvalCaseResult] = [
          eval_result async for eval_result in agen
      ]

    return eval_results

  def _extract_eval_data(
      self,
      eval_set_id: str,
      eval_results: list[EvalCaseResult],
  ) -> dict[str, dict[str, Any]]:
    """Extracts evaluation data from the eval results."""
    eval_data = {}
    for eval_result in eval_results:
      eval_result_dict = {}
      eval_case = self._eval_sets_manager.get_eval_case(
          app_name=self._config.app_name,
          eval_set_id=eval_set_id,
          eval_case_id=eval_result.eval_id,
      )
      if eval_case and eval_case.conversation_scenario:
        eval_result_dict["conversation_scenario"] = (
            eval_case.conversation_scenario
        )

      per_invocation_results = []
      for (
          per_invocation_result
      ) in eval_result.eval_metric_result_per_invocation:
        eval_metric_results = []
        for eval_metric_result in per_invocation_result.eval_metric_results:
          eval_metric_results.append({
              "metric_name": eval_metric_result.metric_name,
              "score": round(eval_metric_result.score, 2),  # accurate enough
              "eval_status": eval_metric_result.eval_status.name,
          })
        per_invocation_result_dict = {
            "actual_invocation": extract_single_invocation_info(
                per_invocation_result.actual_invocation
            ),
            "eval_metric_results": eval_metric_results,
        }
        if per_invocation_result.expected_invocation:
          per_invocation_result_dict["expected_invocation"] = (
              extract_single_invocation_info(
                  per_invocation_result.expected_invocation
              )
          )
        per_invocation_results.append(per_invocation_result_dict)
      eval_result_dict["invocations"] = per_invocation_results
      eval_data[eval_result.eval_id] = eval_result_dict

    return eval_data

  def get_train_example_ids(self) -> list[str]:
    """Returns the UIDs of examples to use for training the agent."""
    return self._train_eval_case_ids

  def get_validation_example_ids(self) -> list[str]:
    """Returns the UIDs of examples to use for validating the optimized agent."""
    return self._validation_eval_case_ids

  async def sample_and_score(
      self,
      candidate: Agent,
      example_set: Literal[
          Sampler.TRAIN_SET, Sampler.VALIDATION_SET
      ] = Sampler.VALIDATION_SET,
      batch: Optional[list[str]] = None,
      capture_full_eval_data: bool = False,
  ) -> UnstructuredSamplingResult:
    """Evaluates the candidate agent on the batch of examples using the ADK LocalEvalService.

    Args:
      candidate: The candidate agent to be evaluated.
      example_set: The set of examples to evaluate the candidate agent on.
        Possible values are "train" and "validation".
      batch: UIDs of examples to evaluate the candidate agent on. If not
        provided, all examples from the chosen set will be used.
      capture_full_eval_data: If false, it is enough to only calculate the
        scores for each example. If true, this method should also capture all
        other data required for optimizing the agent (e.g., outputs,
        trajectories, and tool calls).

    Returns:
      The evaluation results, containing the scores for each example and (if
      requested) other data required for optimization.
    """
    eval_set_id = self._get_selected_example_set_id(example_set)
    if batch is None:
      batch = self._get_all_example_ids(example_set)

    eval_results = await self._evaluate_agent(candidate, eval_set_id, batch)
    _log_eval_summary(eval_results)

    scores = {
        eval_result.eval_id: (
            1.0 if eval_result.final_eval_status == EvalStatus.PASSED else 0.0
        )
        for eval_result in eval_results
    }

    eval_data = (
        self._extract_eval_data(eval_set_id, eval_results)
        if capture_full_eval_data
        else None
    )

    return UnstructuredSamplingResult(scores=scores, data=eval_data)

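As a rough illustration of the "compact format" mentioned in the commit message, the data produced by `_extract_eval_data` above (returned in `UnstructuredSamplingResult.data` when `capture_full_eval_data=True`) has roughly the shape sketched below; the case id, scenario text, metric name, and all values are invented for the example.

example_eval_data = {
    "eval_case_1": {
        # present only if the eval case defines a conversation scenario
        "conversation_scenario": "User asks the agent to book a table.",
        "invocations": [
            {
                "actual_invocation": {
                    "user_prompt": "Book a table for two at 7pm.",
                    "agent_response": "Done, your table is booked.",
                    # present only if the invocation has intermediate data
                    "tool_calls": [
                        {
                            "name": "book_table",
                            "args": {"party_size": 2, "time": "19:00"},
                            "response": {"status": "confirmed"},
                        }
                    ],
                },
                # present only if an expected (reference) invocation exists
                "expected_invocation": {
                    "user_prompt": "Book a table for two at 7pm.",
                    "agent_response": "Your table is booked.",
                },
                "eval_metric_results": [
                    {
                        "metric_name": "tool_trajectory_avg_score",
                        "score": 1.0,  # rounded to two decimals
                        "eval_status": "PASSED",
                    }
                ],
            }
        ],
    }
}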
google/adk/optimization/sampler.py
@@ -32,6 +32,9 @@ class Sampler(ABC, Generic[SamplingResult]):
   to get evaluation results for the candidate agent on the batch of examples.
   """
 
+  TRAIN_SET = "train"
+  VALIDATION_SET = "validation"
+
   @abstractmethod
   def get_train_example_ids(self) -> list[str]:
     """Returns the UIDs of examples to use for training the agent."""
@@ -46,7 +49,7 @@ class Sampler(ABC, Generic[SamplingResult]):
   async def sample_and_score(
       self,
       candidate: Agent,
-      example_set: Literal["train", "validation"] = "validation",
+      example_set: Literal[TRAIN_SET, VALIDATION_SET] = VALIDATION_SET,
       batch: Optional[list[str]] = None,
       capture_full_eval_data: bool = False,
   ) -> SamplingResult:
unit tests for LocalEvalSampler (new file)
@@ -0,0 +1,383 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from google.adk.agents.llm_agent import Agent
from google.adk.evaluation.base_eval_service import EvaluateConfig
from google.adk.evaluation.base_eval_service import EvaluateRequest
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.base_eval_service import InferenceResult
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_case import InvocationEvent
from google.adk.evaluation.eval_case import InvocationEvents
from google.adk.evaluation.eval_config import EvalConfig
from google.adk.evaluation.eval_config import EvalMetric
from google.adk.evaluation.eval_metrics import EvalMetricResult
from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation
from google.adk.evaluation.eval_metrics import EvalStatus
from google.adk.evaluation.eval_result import EvalCaseResult
from google.adk.evaluation.eval_sets_manager import EvalSetsManager
from google.adk.optimization.local_eval_sampler import _log_eval_summary
from google.adk.optimization.local_eval_sampler import extract_single_invocation_info
from google.adk.optimization.local_eval_sampler import extract_tool_call_data
from google.adk.optimization.local_eval_sampler import LocalEvalSampler
from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig
from google.genai import types
import pytest


def test_log_eval_summary(mocker):
  statuses = (
      [EvalStatus.PASSED] * 3
      + [EvalStatus.FAILED] * 2
      + [EvalStatus.NOT_EVALUATED]
  )
  expected_log = "Evaluation summary: 3 PASSED, 2 FAILED, 1 OTHER"

  eval_results = [
      mocker.MagicMock(spec=EvalCaseResult, final_eval_status=status)
      for status in statuses
  ]
  mock_logger = mocker.patch(
      "google.adk.optimization.local_eval_sampler.logger"
  )

  _log_eval_summary(eval_results)

  mock_logger.info.assert_called_once_with(expected_log)

def test_extract_tool_call_data():
  # omitting IntermediateData tests as it is no longer used
  # case 1: empty invocation events
  assert not extract_tool_call_data(InvocationEvents())
  # case 2: multi call invocation events
  multi_call_invocation_events = InvocationEvents(
      invocation_events=[
          InvocationEvent(
              author="agent",
              content=types.Content(
                  parts=[
                      types.Part(
                          function_call=types.FunctionCall(
                              id="call_1",
                              name="tool_1",
                              args={"a": 1},
                          )
                      ),
                      types.Part(
                          function_call=types.FunctionCall(
                              id="call_2",
                              name="tool_2",
                              args={"b": 2},
                          )
                      ),
                      types.Part(
                          function_response=types.FunctionResponse(
                              id="call_1",
                              name="tool_1",
                              response={"result_1": "done"},
                          )
                      ),
                      types.Part(
                          function_response=types.FunctionResponse(
                              id="call_2",
                              name="tool_2",
                              response={"result_2": "done"},
                          )
                      ),
                  ]
              ),
          )
      ]
  )
  expected_entries = [
      {
          "name": "tool_1",
          "args": {"a": 1},
          "response": {"result_1": "done"},
      },
      {
          "name": "tool_2",
          "args": {"b": 2},
          "response": {"result_2": "done"},
      },
  ]
  result = extract_tool_call_data(multi_call_invocation_events)
  # order is not guaranteed
  for expected_entry in expected_entries:
    assert expected_entry in result
  assert len(result) == len(expected_entries)


def test_extract_single_invocation_info():
  invocation = Invocation(
      user_content=types.Content(
          parts=[
              types.Part(text="user thought", thought=True),
              types.Part(text="Hello agent!"),
          ]
      ),
      final_response=types.Content(
          parts=[
              types.Part(text="agent thought", thought=True),
              types.Part(text="Hello user!"),
          ]
      ),
  )

  result = extract_single_invocation_info(invocation)

  assert result == {
      "user_prompt": "Hello agent!",
      "agent_response": "Hello user!",
  }

@pytest.mark.parametrize(
    "config_kwargs, expected_attrs",
    [
        (
            {"train_eval_set": "train_set"},
            {
                "_train_eval_set": "train_set",
                "_train_eval_case_ids": ["train_set_1", "train_set_2"],
                "_validation_eval_set": "train_set",
                "_validation_eval_case_ids": ["train_set_1", "train_set_2"],
            },
        ),
        (
            {"train_eval_set": "train_set", "train_eval_case_ids": ["t1"]},
            {
                "_train_eval_case_ids": ["t1"],
                "_validation_eval_case_ids": ["t1"],
            },
        ),
        (
            {"train_eval_set": "train_set", "validation_eval_set": "val_set"},
            {
                "_validation_eval_set": "val_set",
                "_validation_eval_case_ids": ["val_set_1", "val_set_2"],
            },
        ),
        (
            {"train_eval_set": "train_set", "validation_eval_case_ids": ["v1"]},
            {
                "_validation_eval_case_ids": ["v1"],
            },
        ),
        (
            {
                "train_eval_set": "train_set",
                "train_eval_case_ids": ["t1"],
                "validation_eval_set": "val_set",
                "validation_eval_case_ids": ["v1"],
            },
            {
                "_train_eval_case_ids": ["t1"],
                "_validation_eval_set": "val_set",
                "_validation_eval_case_ids": ["v1"],
            },
        ),
    ],
)
def test_local_eval_service_interface_init(
    mocker, config_kwargs, expected_attrs
):
  mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager)

  def mock_get_eval_case_ids(self, eval_set_id):
    return [f"{eval_set_id}_1", f"{eval_set_id}_2"]

  mocker.patch.object(
      LocalEvalSampler,
      "_get_eval_case_ids",
      autospec=True,
      side_effect=mock_get_eval_case_ids,
  )

  config = LocalEvalSamplerConfig(
      eval_config=EvalConfig(), app_name="test_app", **config_kwargs
  )
  interface = LocalEvalSampler(config, mock_eval_sets_manager)

  for attr, expected_value in expected_attrs.items():
    assert getattr(interface, attr) == expected_value

@pytest.mark.asyncio
async def test_evaluate_agent(mocker):
  # Mocking LocalEvalService and its methods
  mock_eval_service_cls = mocker.patch(
      "google.adk.optimization.local_eval_sampler.LocalEvalService"
  )
  mock_eval_service = mock_eval_service_cls.return_value

  # mocking inference
  mock_inference_result = mocker.MagicMock(spec=InferenceResult)

  async def mock_perform_inference(*args, **kwargs):
    yield mock_inference_result

  mock_eval_service.perform_inference.side_effect = mock_perform_inference

  # mocking evaluate
  mock_eval_case_result = mocker.MagicMock(spec=EvalCaseResult)

  async def mock_evaluate(*args, **kwargs):
    yield mock_eval_case_result

  mock_eval_service.evaluate.side_effect = mock_evaluate

  # mocking get_eval_metrics_from_config
  mock_metrics = [EvalMetric(metric_name="test_metric")]
  mocker.patch(
      "google.adk.optimization.local_eval_sampler.get_eval_metrics_from_config",
      return_value=mock_metrics,
  )

  mocker.patch("google.adk.evaluation.base_eval_service.EvaluateConfig")

  # Initialize Interface
  config = LocalEvalSamplerConfig(
      eval_config=EvalConfig(),
      app_name="test_app",
      train_eval_set="train_set",
      train_eval_case_ids=["t1"],
  )
  interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager))

  # Call _evaluate_agent
  results = await interface._evaluate_agent(
      mocker.MagicMock(spec=Agent), "train_set", ["t1"]
  )

  # Assertions
  mock_eval_service.perform_inference.assert_called_once_with(
      inference_request=InferenceRequest(
          app_name="test_app",
          eval_set_id="train_set",
          eval_case_ids=["t1"],
          inference_config=InferenceConfig(),
      )
  )
  mock_eval_service.evaluate.assert_called_once_with(
      evaluate_request=EvaluateRequest(
          inference_results=[mock_inference_result],
          evaluate_config=EvaluateConfig(eval_metrics=mock_metrics),
      )
  )
  assert results == [mock_eval_case_result]


@pytest.mark.asyncio
async def test_extract_eval_data(mocker):
  # Mock components
  mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager)
  mock_eval_case = mocker.MagicMock()
  mock_eval_case.conversation_scenario = "test_scenario"
  mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case

  # Mock per invocation result
  mock_actual_invocation = mocker.MagicMock(spec=Invocation)
  mock_expected_invocation = mocker.MagicMock(spec=Invocation)
  mock_metric_result = mocker.MagicMock(spec=EvalMetricResult)
  mock_metric_result.metric_name = "test_metric"
  mock_metric_result.score = 0.854  # should be rounded to 0.85
  mock_metric_result.eval_status = EvalStatus.PASSED

  mock_per_inv_result = mocker.MagicMock(spec=EvalMetricResultPerInvocation)
  mock_per_inv_result.actual_invocation = mock_actual_invocation
  mock_per_inv_result.expected_invocation = mock_expected_invocation
  mock_per_inv_result.eval_metric_results = [mock_metric_result]

  mock_eval_result = mocker.MagicMock(spec=EvalCaseResult)
  mock_eval_result.eval_id = "t1"
  mock_eval_result.eval_metric_result_per_invocation = [mock_per_inv_result]

  # Mock extract_single_invocation_info
  mocker.patch(
      "google.adk.optimization.local_eval_sampler.extract_single_invocation_info",
      side_effect=[{"info": "actual"}, {"info": "expected"}],
  )

  # Initialize Interface
  config = LocalEvalSamplerConfig(
      eval_config=EvalConfig(),
      app_name="test_app",
      train_eval_set="train_set",
      train_eval_case_ids=["t1"],
  )
  interface = LocalEvalSampler(config, mock_eval_sets_manager)

  # Call _extract_eval_data
  eval_data = interface._extract_eval_data("train_set", [mock_eval_result])

  # Assertions
  assert "t1" in eval_data
  assert eval_data["t1"]["conversation_scenario"] == "test_scenario"
  assert len(eval_data["t1"]["invocations"]) == 1
  inv = eval_data["t1"]["invocations"][0]
  assert inv["actual_invocation"] == {"info": "actual"}
  assert inv["expected_invocation"] == {"info": "expected"}
  assert inv["eval_metric_results"] == [
      {"metric_name": "test_metric", "score": 0.85, "eval_status": "PASSED"}
  ]

@pytest.mark.asyncio
async def test_sample_and_score(mocker):
  # Mock results
  mock_eval_result_1 = mocker.MagicMock(spec=EvalCaseResult)
  mock_eval_result_1.eval_id = "t1"
  mock_eval_result_1.final_eval_status = EvalStatus.PASSED

  mock_eval_result_2 = mocker.MagicMock(spec=EvalCaseResult)
  mock_eval_result_2.eval_id = "t2"
  mock_eval_result_2.final_eval_status = EvalStatus.FAILED

  eval_results = [mock_eval_result_1, mock_eval_result_2]

  # Initialize Interface
  config = LocalEvalSamplerConfig(
      eval_config=EvalConfig(),
      app_name="test_app",
      train_eval_set="train_set",
      train_eval_case_ids=["t1", "t2"],
  )
  interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager))

  # Patch internal methods
  mocker.patch.object(interface, "_evaluate_agent", return_value=eval_results)
  mock_log_summary = mocker.patch(
      "google.adk.optimization.local_eval_sampler._log_eval_summary"
  )
  mock_extract_data = mocker.patch.object(
      interface, "_extract_eval_data", return_value={"t1": {}, "t2": {}}
  )

  # Call sample_and_score
  result = await interface.sample_and_score(
      mocker.MagicMock(spec=Agent),
      example_set="train",
      capture_full_eval_data=True,
  )

  # Assertions
  assert result.scores == {"t1": 1.0, "t2": 0.0}
  assert result.data == {"t1": {}, "t2": {}}
  mock_log_summary.assert_called_once_with(eval_results)
  mock_extract_data.assert_called_once_with("train_set", eval_results)