You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Update adk eval cli to consume custom metrics by adding CustomMetricEvaluator
Co-authored-by: Joseph Pagadora <jcpagadora@google.com> PiperOrigin-RevId: 857229167
This commit is contained in:
committed by
Copybara-Service
parent
5923da786e
commit
ea0934b993
@@ -34,6 +34,9 @@ from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
from ..evaluation.eval_case import get_all_tool_calls
|
||||
from ..evaluation.eval_case import IntermediateDataType
|
||||
from ..evaluation.eval_metrics import EvalMetric
|
||||
from ..evaluation.eval_metrics import Interval
|
||||
from ..evaluation.eval_metrics import MetricInfo
|
||||
from ..evaluation.eval_metrics import MetricValueInfo
|
||||
from ..evaluation.eval_result import EvalCaseResult
|
||||
from ..evaluation.eval_sets_manager import EvalSetsManager
|
||||
from ..utils.context_utils import Aclosing
|
||||
@@ -70,6 +73,19 @@ def _get_agent_module(agent_module_file_path: str):
|
||||
return _import_from_path(module_name, file_path)
|
||||
|
||||
|
||||
def get_default_metric_info(
    metric_name: str, description: str = ""
) -> MetricInfo:
  """Builds the fallback MetricInfo used when no explicit one is configured.

  Args:
    metric_name: Name to record on the returned MetricInfo.
    description: Optional human readable description; defaults to an empty
      string.

  Returns:
    A MetricInfo whose value interval spans 0.0 to 1.0.
  """
  # Scores of metrics without explicit metric info are described as lying in
  # the interval from 0.0 to 1.0.
  unit_interval = Interval(min_value=0.0, max_value=1.0)
  value_info = MetricValueInfo(interval=unit_interval)
  return MetricInfo(
      metric_name=metric_name,
      description=description,
      metric_value_info=value_info,
  )
|
||||
|
||||
|
||||
def get_root_agent(agent_module_file_path: str) -> Agent:
|
||||
"""Returns root agent given the agent module."""
|
||||
agent_module = _get_agent_module(agent_module_file_path)
|
||||
|
||||
@@ -712,8 +712,11 @@ def cli_eval(
|
||||
logs.setup_adk_logger(getattr(logging, log_level.upper()))
|
||||
|
||||
try:
|
||||
import importlib
|
||||
|
||||
from ..evaluation.base_eval_service import InferenceConfig
|
||||
from ..evaluation.base_eval_service import InferenceRequest
|
||||
from ..evaluation.custom_metric_evaluator import _CustomMetricEvaluator
|
||||
from ..evaluation.eval_config import get_eval_metrics_from_config
|
||||
from ..evaluation.eval_config import get_evaluation_criteria_or_default
|
||||
from ..evaluation.eval_result import EvalCaseResult
|
||||
@@ -723,9 +726,11 @@ def cli_eval(
|
||||
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
||||
from ..evaluation.local_eval_sets_manager import load_eval_set_from_file
|
||||
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||
from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
|
||||
from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider
|
||||
from .cli_eval import _collect_eval_results
|
||||
from .cli_eval import _collect_inferences
|
||||
from .cli_eval import get_default_metric_info
|
||||
from .cli_eval import get_root_agent
|
||||
from .cli_eval import parse_and_get_evals_to_run
|
||||
from .cli_eval import pretty_print_eval_result
|
||||
@@ -818,11 +823,30 @@ def cli_eval(
|
||||
)
|
||||
|
||||
try:
|
||||
metric_evaluator_registry = DEFAULT_METRIC_EVALUATOR_REGISTRY
|
||||
if eval_config.custom_metrics:
|
||||
for (
|
||||
metric_name,
|
||||
config,
|
||||
) in eval_config.custom_metrics.items():
|
||||
if config.metric_info:
|
||||
metric_info = config.metric_info.model_copy()
|
||||
metric_info.metric_name = metric_name
|
||||
else:
|
||||
metric_info = get_default_metric_info(
|
||||
metric_name=metric_name, description=config.description
|
||||
)
|
||||
|
||||
metric_evaluator_registry.register_evaluator(
|
||||
metric_info, _CustomMetricEvaluator
|
||||
)
|
||||
|
||||
eval_service = LocalEvalService(
|
||||
root_agent=root_agent,
|
||||
eval_sets_manager=eval_sets_manager,
|
||||
eval_set_results_manager=eval_set_results_manager,
|
||||
user_simulator_provider=user_simulator_provider,
|
||||
metric_evaluator_registry=metric_evaluator_registry,
|
||||
)
|
||||
|
||||
inference_results = asyncio.run(
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .eval_case import ConversationScenario
|
||||
from .eval_case import Invocation
|
||||
from .eval_metrics import EvalMetric
|
||||
from .eval_metrics import EvalStatus
|
||||
from .evaluator import EvaluationResult
|
||||
from .evaluator import Evaluator
|
||||
|
||||
|
||||
def _get_metric_function(
    custom_function_path: str,
) -> Callable[..., EvaluationResult]:
  """Resolves a dotted path to the custom metric callable.

  Args:
    custom_function_path: Fully qualified dotted path of the form
      `package.module.function_name`. Everything before the last dot is
      imported as a module; the part after it is looked up on that module.

  Returns:
    The callable found at `custom_function_path`.

  Raises:
    ImportError: If the path cannot be split into module and attribute, the
      module cannot be imported, the attribute does not exist, or the
      resolved attribute is not callable.
  """
  try:
    module_name, function_name = custom_function_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    metric_function = getattr(module, function_name)
  # TypeError is raised by import_module for relative-looking module names
  # (leading dot without a package); surface it like any other bad path.
  except (ImportError, AttributeError, ValueError, TypeError) as e:
    raise ImportError(
        f"Could not import custom metric function from {custom_function_path}"
    ) from e

  # Fail here, at configuration time, rather than with an opaque TypeError
  # when the evaluator later tries to call a non-callable attribute.
  if not callable(metric_function):
    raise ImportError(
        f"Could not import custom metric function from {custom_function_path}:"
        f" '{function_name}' is not callable"
    )
  return metric_function
|
||||
|
||||
|
||||
def _get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
  """Translates a score into an eval status relative to `threshold`.

  Args:
    score: The score to classify, or None when no score was produced.
    threshold: Minimum score (inclusive) that counts as passing.

  Returns:
    NOT_EVALUATED when `score` is None, PASSED when `score >= threshold`,
    FAILED otherwise.
  """
  if score is None:
    return EvalStatus.NOT_EVALUATED
  if score >= threshold:
    return EvalStatus.PASSED
  return EvalStatus.FAILED
|
||||
|
||||
|
||||
class _CustomMetricEvaluator(Evaluator):
  """Evaluator that delegates scoring to a user-supplied metric function.

  The metric function is located through a dotted path and is called with
  `(actual_invocations, expected_invocations, conversation_scenario)`. It may
  be synchronous or asynchronous; its result is expected to expose
  `overall_score` and a writable `overall_eval_status`.
  """

  def __init__(self, eval_metric: EvalMetric, custom_function_path: str):
    """Stores the metric and resolves the custom function.

    Args:
      eval_metric: Metric whose `threshold` decides pass/fail.
      custom_function_path: Dotted path handed to `_get_metric_function`.
    """
    self._eval_metric = eval_metric
    self._metric_function = _get_metric_function(custom_function_path)

  @override
  async def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: Optional[list[Invocation]],
      conversation_scenario: Optional[ConversationScenario] = None,
  ) -> EvaluationResult:
    """Runs the custom metric function and derives the overall status.

    Args:
      actual_invocations: Forwarded unchanged to the metric function.
      expected_invocations: Forwarded unchanged to the metric function.
      conversation_scenario: Forwarded unchanged to the metric function.

    Returns:
      The result produced by the metric function, with `overall_eval_status`
      overwritten from `overall_score` and the metric's threshold.
    """
    eval_result = self._metric_function(
        actual_invocations, expected_invocations, conversation_scenario
    )
    # Await based on what was actually returned rather than on how the
    # function was declared, so callable objects with an async `__call__` and
    # sync wrappers that return an awaitable are handled as well as plain
    # `async def` functions.
    if inspect.isawaitable(eval_result):
      eval_result = await eval_result

    eval_result.overall_eval_status = _get_eval_status(
        eval_result.overall_score, self._eval_metric.threshold
    )
    return eval_result
|
||||
@@ -28,12 +28,46 @@ from pydantic import model_validator
|
||||
from ..agents.common_configs import CodeConfig
|
||||
from ..evaluation.eval_metrics import EvalMetric
|
||||
from .eval_metrics import BaseCriterion
|
||||
from .eval_metrics import MetricInfo
|
||||
from .eval_metrics import Threshold
|
||||
from .simulation.user_simulator import BaseUserSimulatorConfig
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class CustomMetricConfig(BaseModel):
  """Configuration for a custom metric."""

  # Field names are exposed under camelCase aliases; populating by the
  # snake_case field name is accepted as well.
  model_config = ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
  )

  # Required: identifies the function implementing the metric.
  code_config: CodeConfig = Field(
      description="Code config for the custom metric, used to locate the custom metric function."
  )
  # Optional explicit metric info for the metric.
  metric_info: Optional[MetricInfo] = Field(
      default=None, description="Metric info for the custom metric."
  )
  # Free-form description; empty by default.
  description: str = Field(
      default="", description="Description for the custom metric info."
  )

  @model_validator(mode="after")
  def check_code_config_args(self) -> "CustomMetricConfig":
    """Rejects configs whose `code_config` carries `args`.

    Returns:
      The validated model itself.

    Raises:
      ValueError: If `code_config.args` is truthy.
    """
    if not self.code_config.args:
      return self
    raise ValueError(
        "args field in CodeConfig for custom metric is not supported."
    )
|
||||
|
||||
|
||||
class EvalConfig(BaseModel):
|
||||
"""Configurations needed to run an Eval.
|
||||
|
||||
@@ -74,24 +108,43 @@ the third one uses `LlmAsAJudgeCriterion`.
|
||||
""",
|
||||
)
|
||||
|
||||
custom_metrics: Optional[dict[str, CodeConfig]] = Field(
|
||||
custom_metrics: Optional[dict[str, CustomMetricConfig]] = Field(
|
||||
default=None,
|
||||
description="""A dictionary mapping custom metric names to CodeConfig
|
||||
objects, which specify the path to the function for each custom metric.
|
||||
description="""A dictionary mapping custom metric names to
|
||||
a CustomMetricConfig object.
|
||||
|
||||
If a metric name in `criteria` is also present in `custom_metrics`, the
|
||||
corresponding `CodeConfig`'s `name` field will be used to locate the custom
|
||||
metric implementation. The `name` field should contain the fully qualified
|
||||
path to the custom metric function, e.g., `my.custom.metrics.metric_function`.
|
||||
`code_config` in `CustomMetricConfig` will be used to locate the custom metric
|
||||
implementation.
|
||||
|
||||
The `metric` field in `CustomMetricConfig` can be used to provide metric
|
||||
information like `min_value`, `max_value`, and `description`. If `metric`
|
||||
is not provided, a default `MetricInfo` will be created, using
|
||||
`description` from `CustomMetricConfig` if provided, and default values
|
||||
for `min_value` (0.0) and `max_value` (1.0).
|
||||
|
||||
Example:
|
||||
{
|
||||
"criteria": {
|
||||
"my_custom_metric": 0.5
|
||||
"my_custom_metric": 0.5,
|
||||
"my_simple_metric": 0.8
|
||||
},
|
||||
"custom_metrics": {
|
||||
"my_simple_metric": {
|
||||
"code_config": {
|
||||
"name": "path.to.my.simple.metric.function"
|
||||
}
|
||||
},
|
||||
"my_custom_metric": {
|
||||
"name": "path.to.my.custom.metric.function"
|
||||
"code_config": {
|
||||
"name": "path.to.my.custom.metric.function"
|
||||
},
|
||||
"metric": {
|
||||
"metric_name": "my_custom_metric",
|
||||
"min_value": -10.0,
|
||||
"max_value": 10.0,
|
||||
"description": "My custom metric."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,17 +156,6 @@ Example:
|
||||
description="Config to be used by the user simulator.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_custom_metrics_code_config_args(self) -> "EvalConfig":
|
||||
if self.custom_metrics:
|
||||
for metric_name, metric_config in self.custom_metrics.items():
|
||||
if metric_config.args:
|
||||
raise ValueError(
|
||||
f"args field in CodeConfig for custom metric '{metric_name}' is"
|
||||
" not supported."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
_DEFAULT_EVAL_CONFIG = EvalConfig(
|
||||
criteria={"tool_trajectory_avg_score": 1.0, "response_match_score": 0.8}
|
||||
@@ -144,11 +186,10 @@ def get_eval_metrics_from_config(eval_config: EvalConfig) -> list[EvalMetric]:
|
||||
if eval_config.criteria:
|
||||
for metric_name, criterion in eval_config.criteria.items():
|
||||
custom_function_path = None
|
||||
if (
|
||||
eval_config.custom_metrics
|
||||
and metric_name in eval_config.custom_metrics
|
||||
if eval_config.custom_metrics and (
|
||||
config := eval_config.custom_metrics.get(metric_name)
|
||||
):
|
||||
custom_function_path = eval_config.custom_metrics[metric_name].name
|
||||
custom_function_path = config.code_config.name
|
||||
|
||||
if isinstance(criterion, float):
|
||||
eval_metric_list.append(
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .custom_metric_evaluator import _CustomMetricEvaluator
|
||||
from .eval_metrics import EvalMetric
|
||||
from .eval_metrics import MetricInfo
|
||||
from .eval_metrics import PrebuiltMetrics
|
||||
@@ -62,7 +63,13 @@ class MetricEvaluatorRegistry:
|
||||
if eval_metric.metric_name not in self._registry:
|
||||
raise NotFoundError(f"{eval_metric.metric_name} not found in registry.")
|
||||
|
||||
return self._registry[eval_metric.metric_name][0](eval_metric=eval_metric)
|
||||
evaluator_type = self._registry[eval_metric.metric_name][0]
|
||||
if issubclass(evaluator_type, _CustomMetricEvaluator):
|
||||
return evaluator_type(
|
||||
eval_metric=eval_metric,
|
||||
custom_function_path=eval_metric.custom_function_path,
|
||||
)
|
||||
return evaluator_type(eval_metric=eval_metric)
|
||||
|
||||
def register_evaluator(
|
||||
self,
|
||||
|
||||
@@ -109,8 +109,12 @@ def test_get_eval_metrics_from_config_with_custom_metrics():
|
||||
},
|
||||
},
|
||||
custom_metrics={
|
||||
"custom_metric_1": {"name": "path/to/custom/metric_1"},
|
||||
"custom_metric_2": {"name": "path/to/custom/metric_2"},
|
||||
"custom_metric_1": {
|
||||
"code_config": {"name": "path/to/custom/metric_1"},
|
||||
},
|
||||
"custom_metric_2": {
|
||||
"code_config": {"name": "path/to/custom/metric_2"},
|
||||
},
|
||||
},
|
||||
)
|
||||
eval_metrics = get_eval_metrics_from_config(eval_config)
|
||||
@@ -128,10 +132,12 @@ def test_get_eval_metrics_from_config_with_custom_metrics():
|
||||
|
||||
def test_custom_metric_code_config_with_args_raises_error():
|
||||
with pytest.raises(ValueError):
|
||||
eval_config = EvalConfig(
|
||||
_ = EvalConfig(
|
||||
criteria={"custom_metric": 1.0},
|
||||
custom_metrics={
|
||||
"custom_metric": {"name": "name", "args": [{"value": 1}]}
|
||||
"custom_metric": {
|
||||
"code_config": {"name": "name", "args": [{"value": 1}]},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user