You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Remove overall_eval_status calculation from _CustomMetricEvaluator and add threshold to custom metric function expected signature
Co-authored-by: Joseph Pagadora <jcpagadora@google.com> PiperOrigin-RevId: 861268984
This commit is contained in:
committed by
Copybara-Service
parent
85434e293f
commit
553e376718
@@ -24,7 +24,6 @@ from typing_extensions import override
|
||||
from .eval_case import ConversationScenario
|
||||
from .eval_case import Invocation
|
||||
from .eval_metrics import EvalMetric
|
||||
from .eval_metrics import EvalStatus
|
||||
from .evaluator import EvaluationResult
|
||||
from .evaluator import Evaluator
|
||||
|
||||
@@ -44,12 +43,6 @@ def _get_metric_function(
|
||||
) from e
|
||||
|
||||
|
||||
def _get_eval_status(score: Optional[float], threshold: float) -> EvalStatus:
|
||||
if score is None:
|
||||
return EvalStatus.NOT_EVALUATED
|
||||
return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
|
||||
|
||||
|
||||
class _CustomMetricEvaluator(Evaluator):
|
||||
"""Evaluator for custom metrics."""
|
||||
|
||||
@@ -64,16 +57,20 @@ class _CustomMetricEvaluator(Evaluator):
|
||||
expected_invocations: Optional[list[Invocation]],
|
||||
conversation_scenario: Optional[ConversationScenario] = None,
|
||||
) -> EvaluationResult:
|
||||
eval_metric = self._eval_metric.model_copy(deep=True)
|
||||
eval_metric.threshold = None
|
||||
if inspect.iscoroutinefunction(self._metric_function):
|
||||
eval_result = await self._metric_function(
|
||||
actual_invocations, expected_invocations, conversation_scenario
|
||||
eval_metric,
|
||||
actual_invocations,
|
||||
expected_invocations,
|
||||
conversation_scenario,
|
||||
)
|
||||
else:
|
||||
eval_result = self._metric_function(
|
||||
actual_invocations, expected_invocations, conversation_scenario
|
||||
eval_metric,
|
||||
actual_invocations,
|
||||
expected_invocations,
|
||||
conversation_scenario,
|
||||
)
|
||||
|
||||
eval_result.overall_eval_status = _get_eval_status(
|
||||
eval_result.overall_score, self._eval_metric.threshold
|
||||
)
|
||||
return eval_result
|
||||
|
||||
@@ -258,9 +258,11 @@ class EvalMetric(EvalBaseModel):
|
||||
description="The name of the metric.",
|
||||
)
|
||||
|
||||
threshold: float = Field(
|
||||
threshold: Optional[float] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A threshold value. Each metric decides how to interpret this"
|
||||
"This field will be deprecated soon. Please use `criterion` instead."
|
||||
" A threshold value. Each metric decides how to interpret this"
|
||||
" threshold."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
|
||||
from google.adk.evaluation.custom_metric_evaluator import _get_metric_function
|
||||
from google.adk.evaluation.eval_case import ConversationScenario
|
||||
from google.adk.evaluation.eval_case import Invocation
|
||||
from google.adk.evaluation.eval_metrics import EvalMetric
|
||||
from google.adk.evaluation.evaluator import EvaluationResult
|
||||
import pytest
|
||||
|
||||
|
||||
def my_sync_metric_function(
|
||||
eval_metric: EvalMetric,
|
||||
actual_invocations: list[Invocation],
|
||||
expected_invocations: list[Invocation] | None,
|
||||
conversation_scenario: ConversationScenario | None,
|
||||
) -> EvaluationResult:
|
||||
"""Sync metric function for testing."""
|
||||
return EvaluationResult(overall_score=1.0)
|
||||
|
||||
|
||||
async def my_async_metric_function(
|
||||
eval_metric: EvalMetric,
|
||||
actual_invocations: list[Invocation],
|
||||
expected_invocations: list[Invocation] | None,
|
||||
conversation_scenario: ConversationScenario | None,
|
||||
) -> EvaluationResult:
|
||||
"""Async metric function for testing."""
|
||||
return EvaluationResult(overall_score=0.5)
|
||||
|
||||
|
||||
@mock.patch("importlib.import_module")
|
||||
def test_get_metric_function_success(mock_import_module):
|
||||
"""Tests that _get_metric_function successfully returns a function."""
|
||||
mock_module = mock.MagicMock()
|
||||
mock_module.my_sync_metric_function = my_sync_metric_function
|
||||
mock_import_module.return_value = mock_module
|
||||
func = _get_metric_function(
|
||||
"test_custom_metric_evaluator.my_sync_metric_function"
|
||||
)
|
||||
assert func == my_sync_metric_function
|
||||
|
||||
|
||||
@mock.patch("importlib.import_module", side_effect=ImportError)
|
||||
def test_get_metric_function_module_not_found(mock_import_module):
|
||||
"""Tests that _get_metric_function raises ImportError for non-existent module."""
|
||||
with pytest.raises(ImportError):
|
||||
_get_metric_function("non_existent_module.my_sync_metric_function")
|
||||
|
||||
|
||||
@mock.patch("importlib.import_module")
|
||||
def test_get_metric_function_function_not_found(mock_import_module):
|
||||
"""Tests that _get_metric_function raises ImportError for non-existent function."""
|
||||
mock_import_module.return_value = object()
|
||||
with pytest.raises(ImportError):
|
||||
_get_metric_function(
|
||||
"google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.non_existent_function"
|
||||
)
|
||||
|
||||
|
||||
def test_get_metric_function_malformed_path():
|
||||
"""Tests that _get_metric_function raises ImportError for malformed path."""
|
||||
with pytest.raises(ImportError):
|
||||
_get_metric_function("malformed_path")
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.evaluation.custom_metric_evaluator._get_metric_function",
|
||||
return_value=my_sync_metric_function,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_metric_evaluator_sync_function(mock_get_metric_function):
|
||||
"""Tests that _CustomMetricEvaluator works with a sync metric function."""
|
||||
eval_metric = EvalMetric(metric_name="sync_metric")
|
||||
evaluator = _CustomMetricEvaluator(
|
||||
eval_metric=eval_metric,
|
||||
custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_sync_metric_function",
|
||||
)
|
||||
result = await evaluator.evaluate_invocations([], None)
|
||||
assert result.overall_score == 1.0
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.evaluation.custom_metric_evaluator._get_metric_function",
|
||||
return_value=my_async_metric_function,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_metric_evaluator_async_function(mock_get_metric_function):
|
||||
"""Tests that _CustomMetricEvaluator works with an async metric function."""
|
||||
eval_metric = EvalMetric(metric_name="async_metric")
|
||||
evaluator = _CustomMetricEvaluator(
|
||||
eval_metric=eval_metric,
|
||||
custom_function_path="google.adk.tests.unittests.evaluation.test_custom_metric_evaluator.my_async_metric_function",
|
||||
)
|
||||
result = await evaluator.evaluate_invocations([], None)
|
||||
assert result.overall_score == 0.5
|
||||
Reference in New Issue
Block a user