feat: Remove overall_eval_status calculation from _CustomMetricEvaluator and add threshold to custom metric function expected signature

Co-authored-by: Joseph Pagadora <jcpagadora@google.com>
PiperOrigin-RevId: 861268984
This commit is contained in:
Joseph Pagadora
2026-01-26 11:01:08 -08:00
committed by Copybara-Service
parent 85434e293f
commit 553e376718
3 changed files with 124 additions and 15 deletions
@@ -24,7 +24,6 @@ from typing_extensions import override
from .eval_case import ConversationScenario
from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .eval_metrics import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import Evaluator
@@ -44,12 +43,6 @@ def _get_metric_function(
) from e
def _get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
if score is None:
return EvalStatus.NOT_EVALUATED
return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
class _CustomMetricEvaluator(Evaluator):
"""Evaluator for custom metrics."""
@@ -64,16 +57,20 @@ class _CustomMetricEvaluator(Evaluator):
expected_invocations: Optional[list[Invocation]],
conversation_scenario: Optional[ConversationScenario] = None,
) -> EvaluationResult:
eval_metric = self._eval_metric.model_copy(deep=True)
eval_metric.threshold = None
if inspect.iscoroutinefunction(self._metric_function):
eval_result = await self._metric_function(
actual_invocations, expected_invocations, conversation_scenario
eval_metric,
actual_invocations,
expected_invocations,
conversation_scenario,
)
else:
eval_result = self._metric_function(
actual_invocations, expected_invocations, conversation_scenario
eval_metric,
actual_invocations,
expected_invocations,
conversation_scenario,
)
eval_result.overall_eval_status = _get_eval_status(
eval_result.overall_score, self._eval_metric.threshold
)
return eval_result
+4 -2
View File
@@ -258,9 +258,11 @@ class EvalMetric(EvalBaseModel):
description="The name of the metric.",
)
threshold: float = Field(
threshold: Optional[float] = Field(
default=None,
description=(
"A threshold value. Each metric decides how to interpret this"
"This field will be deprecated soon. Please use `criterion` instead."
" A threshold value. Each metric decides how to interpret this"
" threshold."
),
)
@@ -0,0 +1,110 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
from google.adk.evaluation.custom_metric_evaluator import _get_metric_function
from google.adk.evaluation.eval_case import ConversationScenario
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_metrics import EvalMetric
from google.adk.evaluation.evaluator import EvaluationResult
import pytest
def my_sync_metric_function(
eval_metric: EvalMetric,
actual_invocations: list[Invocation],
expected_invocations: list[Invocation] | None,
conversation_scenario: ConversationScenario | None,
) -> EvaluationResult:
"""Sync metric function for testing."""
return EvaluationResult(overall_score=1.0)
async def my_async_metric_function(
eval_metric: EvalMetric,
actual_invocations: list[Invocation],
expected_invocations: list[Invocation] | None,
conversation_scenario: ConversationScenario | None,
) -> EvaluationResult:
"""Async metric function for testing."""
return EvaluationResult(overall_score=0.5)
@mock.patch("importlib.import_module")
def test_get_metric_function_success(mock_import_module):
"""Tests that _get_metric_function successfully returns a function."""
mock_module = mock.MagicMock()
mock_module.my_sync_metric_function = my_sync_metric_function
mock_import_module.return_value = mock_module
func = _get_metric_function(
"test_custom_metric_evaluator.my_sync_metric_function"
)
assert func == my_sync_metric_function
@mock.patch("importlib.import_module", side_effect=ImportError)
def test_get_metric_function_module_not_found(mock_import_module):
"""Tests that _get_metric_function raises ImportError for non-existent module."""
with pytest.raises(ImportError):
_get_metric_function("non_existent_module.my_sync_metric_function")
@mock.patch("importlib.import_module")
def test_get_metric_function_function_not_found(mock_import_module):
"""Tests that _get_metric_function raises ImportError for non-existent function."""
mock_import_module.return_value = object()
with pytest.raises(ImportError):
_get_metric_function(
"google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.non_existent_function"
)
def test_get_metric_function_malformed_path():
"""Tests that _get_metric_function raises ImportError for malformed path."""
with pytest.raises(ImportError):
_get_metric_function("malformed_path")
@mock.patch(
"google.adk.evaluation.custom_metric_evaluator._get_metric_function",
return_value=my_sync_metric_function,
)
@pytest.mark.asyncio
async def test_custom_metric_evaluator_sync_function(mock_get_metric_function):
"""Tests that _CustomMetricEvaluator works with a sync metric function."""
eval_metric = EvalMetric(metric_name="sync_metric")
evaluator = _CustomMetricEvaluator(
eval_metric=eval_metric,
custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_sync_metric_function",
)
result = await evaluator.evaluate_invocations([], None)
assert result.overall_score == 1.0
@mock.patch(
"google.adk.evaluation.custom_metric_evaluator._get_metric_function",
return_value=my_async_metric_function,
)
@pytest.mark.asyncio
async def test_custom_metric_evaluator_async_function(mock_get_metric_function):
"""Tests that _CustomMetricEvaluator works with an async metric function."""
eval_metric = EvalMetric(metric_name="async_metric")
evaluator = _CustomMetricEvaluator(
eval_metric=eval_metric,
custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_async_metric_function",
)
result = await evaluator.evaluate_invocations([], None)
assert result.overall_score == 0.5