You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Added support for InOrder and AnyOrder match in ToolTrajectoryAvgScore Metric
Co-authored-by: Ankur Sharma <ankusharma@google.com> PiperOrigin-RevId: 831413968
This commit is contained in:
committed by
Copybara-Service
parent
b2c8ba5806
commit
e2d3b2d862
@@ -150,6 +150,76 @@ class HallucinationsCriterion(BaseCriterion):
|
||||
)
|
||||
|
||||
|
||||
class ToolTrajectoryCriterion(BaseCriterion):
|
||||
"""Criterion to use when evaluating agent's tool trajectories with a reference one."""
|
||||
|
||||
class MatchType(Enum):
|
||||
"""The type of Match between actual and expected tool call trajectories."""
|
||||
|
||||
EXACT = 0
|
||||
"""Requires a perfect match between the actual and expected tool calls."""
|
||||
|
||||
IN_ORDER = 1
|
||||
"""Requires the actual tool calls to be in the same order as expected tools,
|
||||
with allowance for extra tool calls to have happened.
|
||||
|
||||
This criterion is useful for asserting that certain key actions/tool calls
|
||||
occur in a certain order, leaving some scope for other tool calls to
|
||||
happen as well.
|
||||
|
||||
Example 1: Set of actual vs expected tool calls that satisfies the criteria:
|
||||
|
||||
Expected tool calls: [T1, T2, T3]
|
||||
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
|
||||
|
||||
This satisfies, as the tools T1, T2 and T3 happened in the "Actual" and in
|
||||
the same order.
|
||||
|
||||
Example 2: Set of actual vs expected tool calls that don't satisfy the
|
||||
criteria:
|
||||
|
||||
Expected tool calls: [T1, T2, T3, T4]
|
||||
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
|
||||
|
||||
While the tool calls T1, T2 and T3 happened in the "Actual" and in
|
||||
the same order as "Expected", the tool call T4 is missing.
|
||||
"""
|
||||
|
||||
ANY_ORDER = 2
|
||||
"""Requires the actual tool calls to be in the any order as expected tools,
|
||||
with allowance for extra tool calls to have happened.
|
||||
|
||||
This criterion is helpful for cases where multiple tool calls about the same
|
||||
concept occur, like your agent issues 5 search queries. You don't really
|
||||
care about the order in which the search queries are issued, as long as they occur.
|
||||
|
||||
Example 1: Set of actual vs expected tool calls that satisfies the criteria:
|
||||
|
||||
Expected tool calls: [T1, T2, T3]
|
||||
Actual tool calls: [T2, T2.1, T1, T1.1, T1.2, T3, T3.1]
|
||||
|
||||
This satisfies, as the tools T1, T2 and T3 happened in the "Actual" and
|
||||
are also present in expected. Note that the order is different.
|
||||
|
||||
Example 2: Set of actual vs expected tool calls that don't satisfy the
|
||||
criteria:
|
||||
|
||||
Expected tool calls: [T1, T2, T3, T4]
|
||||
Actual tool calls: [T1, T1.1, T2, T2.1, T2.2, T3, T3.1]
|
||||
|
||||
While the tool calls T1, T2 and T3 happened in the "Actual" and in
|
||||
the same order as "Expected", the tool call T4 is missing.
|
||||
"""
|
||||
|
||||
match_type: MatchType = Field(
|
||||
default=MatchType.EXACT,
|
||||
description=(
|
||||
"The type of Match between actual and expected tool call"
|
||||
" trajectories."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EvalMetric(EvalBaseModel):
|
||||
"""A metric used to evaluate a particular aspect of an eval case."""
|
||||
|
||||
|
||||
@@ -14,9 +14,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types as genai_types
|
||||
from pydantic import ValidationError
|
||||
from typing_extensions import override
|
||||
|
||||
from .eval_case import get_all_tool_calls
|
||||
@@ -26,14 +29,43 @@ from .eval_metrics import Interval
|
||||
from .eval_metrics import MetricInfo
|
||||
from .eval_metrics import MetricValueInfo
|
||||
from .eval_metrics import PrebuiltMetrics
|
||||
from .eval_metrics import ToolTrajectoryCriterion
|
||||
from .evaluator import EvalStatus
|
||||
from .evaluator import EvaluationResult
|
||||
from .evaluator import Evaluator
|
||||
from .evaluator import PerInvocationResult
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class TrajectoryEvaluator(Evaluator):
|
||||
"""Evaluates tool use trajectories for accuracy."""
|
||||
"""Evaluates tool use trajectories for accuracy.
|
||||
|
||||
This evaluator compares the sequence of tools called by the agent against a
|
||||
list of expected calls and computes an average score based on one of the match
|
||||
types: `EXACT`, `IN_ORDER`, or `ANY_ORDER`.
|
||||
|
||||
For each invocation being evaluated, this evaluator compares the list of
|
||||
tool calls produced by the agent with the list of expected tool calls using
|
||||
one of three match types. If the tool calls match based on the selected match
|
||||
type, a score of 1.0 is awarded for that invocation, otherwise the score is
|
||||
0.0. The final value is the average of these scores across all
|
||||
invocations in the eval case.
|
||||
|
||||
The comparison can be done using one of following match types:
|
||||
- `EXACT`: Requires a perfect match between the actual and expected tool
|
||||
calls, with no extra or missing tool calls.
|
||||
- `IN_ORDER`: Requires all tool calls from the expected list to be present
|
||||
in the actual list, in the same order, but allows for other tool calls
|
||||
to appear in between.
|
||||
- `ANY_ORDER`: Requires all tool calls from the expected list to be
|
||||
present in the actual list, in any order, and allows for other tool
|
||||
calls to appear in between.
|
||||
"""
|
||||
|
||||
criterion_type: ClassVar[type[ToolTrajectoryCriterion]] = (
|
||||
ToolTrajectoryCriterion
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,10 +78,25 @@ class TrajectoryEvaluator(Evaluator):
|
||||
" specified."
|
||||
)
|
||||
|
||||
if eval_metric:
|
||||
threshold = eval_metric.threshold
|
||||
|
||||
self._threshold = threshold
|
||||
if eval_metric and eval_metric.criterion:
|
||||
try:
|
||||
criterion = TrajectoryEvaluator.criterion_type.model_validate(
|
||||
eval_metric.criterion.model_dump()
|
||||
)
|
||||
self._threshold = criterion.threshold
|
||||
self._match_type = criterion.match_type
|
||||
except ValidationError as e:
|
||||
expected_criterion_type_error = ValueError(
|
||||
f"`{eval_metric.metric_name}` metric expects a criterion of type"
|
||||
f" `{TrajectoryEvaluator.criterion_type}`."
|
||||
)
|
||||
raise expected_criterion_type_error from e
|
||||
elif eval_metric:
|
||||
self._threshold = eval_metric.threshold
|
||||
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
|
||||
else:
|
||||
self._threshold = threshold
|
||||
self._match_type = ToolTrajectoryCriterion.MatchType.EXACT
|
||||
|
||||
@staticmethod
|
||||
def get_metric_info() -> MetricInfo:
|
||||
@@ -82,14 +129,7 @@ class TrajectoryEvaluator(Evaluator):
|
||||
per_invocation_results = []
|
||||
|
||||
for actual, expected in zip(actual_invocations, expected_invocations):
|
||||
actual_tool_uses = get_all_tool_calls(actual.intermediate_data)
|
||||
expected_tool_uses = get_all_tool_calls(expected.intermediate_data)
|
||||
|
||||
tool_use_accuracy = (
|
||||
1.0
|
||||
if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
|
||||
else 0.0
|
||||
)
|
||||
tool_use_accuracy = self._calculate_tool_use_accuracy(actual, expected)
|
||||
per_invocation_results.append(
|
||||
PerInvocationResult(
|
||||
actual_invocation=actual,
|
||||
@@ -111,11 +151,128 @@ class TrajectoryEvaluator(Evaluator):
|
||||
|
||||
return EvaluationResult()
|
||||
|
||||
def _are_tool_calls_equal(
|
||||
def _calculate_tool_use_accuracy(
|
||||
self,
|
||||
actual_invocation: Invocation,
|
||||
expected_invocation: Invocation,
|
||||
) -> float:
|
||||
"""Calculates tool use accuracy for a single invocation."""
|
||||
actual_tool_uses = get_all_tool_calls(actual_invocation.intermediate_data)
|
||||
expected_tool_uses = get_all_tool_calls(
|
||||
expected_invocation.intermediate_data
|
||||
)
|
||||
|
||||
tool_use_match_status = False
|
||||
if self._match_type == ToolTrajectoryCriterion.MatchType.EXACT:
|
||||
tool_use_match_status = self._are_tool_calls_exact_match(
|
||||
actual_tool_uses, expected_tool_uses
|
||||
)
|
||||
elif self._match_type == ToolTrajectoryCriterion.MatchType.IN_ORDER:
|
||||
tool_use_match_status = self._are_tool_calls_in_order_match(
|
||||
actual_tool_uses, expected_tool_uses
|
||||
)
|
||||
elif self._match_type == ToolTrajectoryCriterion.MatchType.ANY_ORDER:
|
||||
tool_use_match_status = self._are_tool_calls_any_order_match(
|
||||
actual_tool_uses, expected_tool_uses
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported match type {self._match_type}")
|
||||
|
||||
return 1.0 if tool_use_match_status else 0.0
|
||||
|
||||
def _are_tool_calls_in_order_match(
|
||||
self,
|
||||
actual_tool_calls: list[genai_types.FunctionCall],
|
||||
expected_tool_calls: list[genai_types.FunctionCall],
|
||||
) -> bool:
|
||||
"""Checks if expected tool calls appear in actual tool calls in order.
|
||||
|
||||
This method implements IN_ORDER match type. It allows for additional
|
||||
tool calls in actual_tool_calls, as long as all expected tool calls are
|
||||
present in the same order.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tool calls that actually happened.
|
||||
expected_tool_calls: A list of tool calls that were expected to happen.
|
||||
|
||||
Returns:
|
||||
True if actual tool calls match expected tool calls in order,
|
||||
False otherwise.
|
||||
"""
|
||||
if not expected_tool_calls:
|
||||
return True
|
||||
if not actual_tool_calls and expected_tool_calls:
|
||||
return False
|
||||
|
||||
expected_it = iter(expected_tool_calls)
|
||||
try:
|
||||
current_expected = next(expected_it)
|
||||
for actual in actual_tool_calls:
|
||||
if (
|
||||
actual.name == current_expected.name
|
||||
and actual.args == current_expected.args
|
||||
):
|
||||
current_expected = next(expected_it)
|
||||
except StopIteration:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _are_tool_calls_any_order_match(
|
||||
self,
|
||||
actual_tool_calls: list[genai_types.FunctionCall],
|
||||
expected_tool_calls: list[genai_types.FunctionCall],
|
||||
) -> bool:
|
||||
"""Checks if expected tool calls appear in actual tool calls in any order.
|
||||
|
||||
This method implements ANY_ORDER match type. It allows for additional
|
||||
tool calls in actual_tool_calls, as long as all expected tool calls are
|
||||
present.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tool calls that actually happened.
|
||||
expected_tool_calls: A list of tool calls that were expected to happen.
|
||||
|
||||
Returns:
|
||||
True if actual tool calls contain all expected tool calls,
|
||||
False otherwise.
|
||||
"""
|
||||
if not expected_tool_calls:
|
||||
return True
|
||||
if not actual_tool_calls and expected_tool_calls:
|
||||
return False
|
||||
|
||||
actual_tool_calls_copy = list(actual_tool_calls)
|
||||
for expected in expected_tool_calls:
|
||||
found = False
|
||||
for i, actual in enumerate(actual_tool_calls_copy):
|
||||
if actual.name == expected.name and actual.args == expected.args:
|
||||
actual_tool_calls_copy.pop(i)
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _are_tool_calls_exact_match(
|
||||
self,
|
||||
actual_tool_calls: list[genai_types.FunctionCall],
|
||||
expected_tool_calls: list[genai_types.FunctionCall],
|
||||
) -> bool:
|
||||
"""Checks if actual tool calls exactly match expected tool calls.
|
||||
|
||||
This method implements EXACT match type. It requires that
|
||||
actual_tool_calls and expected_tool_calls have the same tool calls in
|
||||
the same order, with no extra or missing tool calls.
|
||||
|
||||
Args:
|
||||
actual_tool_calls: A list of tool calls that actually happened.
|
||||
expected_tool_calls: A list of tool calls that were expected to happen.
|
||||
|
||||
Returns:
|
||||
True if actual tool calls exactly match expected tool calls,
|
||||
False otherwise.
|
||||
"""
|
||||
if len(actual_tool_calls) != len(expected_tool_calls):
|
||||
return False
|
||||
|
||||
|
||||
@@ -17,7 +17,9 @@
|
||||
|
||||
from google.adk.evaluation.eval_case import IntermediateData
|
||||
from google.adk.evaluation.eval_case import Invocation
|
||||
from google.adk.evaluation.eval_metrics import EvalMetric
|
||||
from google.adk.evaluation.eval_metrics import PrebuiltMetrics
|
||||
from google.adk.evaluation.eval_metrics import ToolTrajectoryCriterion
|
||||
from google.adk.evaluation.evaluator import EvalStatus
|
||||
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
|
||||
from google.genai import types as genai_types
|
||||
@@ -41,7 +43,16 @@ def test_get_metric_info():
|
||||
@pytest.fixture
|
||||
def evaluator() -> TrajectoryEvaluator:
|
||||
"""Returns a TrajectoryEvaluator."""
|
||||
return TrajectoryEvaluator(threshold=0.5)
|
||||
return TrajectoryEvaluator(
|
||||
eval_metric=EvalMetric(
|
||||
threshold=0.5,
|
||||
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
|
||||
criterion=ToolTrajectoryCriterion(
|
||||
threshold=0.5,
|
||||
match_type=ToolTrajectoryCriterion.MatchType.EXACT,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_evaluate_invocations_equal_tool_calls(evaluator: TrajectoryEvaluator):
|
||||
@@ -176,6 +187,220 @@ def test_evaluate_invocations_multiple_invocations(
|
||||
assert result.per_invocation_results[1].eval_status == EvalStatus.FAILED
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def in_order_evaluator() -> TrajectoryEvaluator:
|
||||
"""Returns a TrajectoryEvaluator for IN_ORDER match."""
|
||||
return TrajectoryEvaluator(
|
||||
eval_metric=EvalMetric(
|
||||
threshold=0.5,
|
||||
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
|
||||
criterion=ToolTrajectoryCriterion(
|
||||
threshold=0.5,
|
||||
match_type=ToolTrajectoryCriterion.MatchType.IN_ORDER,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_evaluate_invocations_in_order_match_with_extra_tool_calls(
|
||||
in_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with IN_ORDER match type and extra tool calls."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t1_1 = genai_types.FunctionCall(name="t1_1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t2_1 = genai_types.FunctionCall(name="t2_1", args={})
|
||||
t3 = genai_types.FunctionCall(name="t3", args={})
|
||||
t3_1 = genai_types.FunctionCall(name="t3_1", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(
|
||||
tool_uses=[t1, t1_1, t2, t2_1, t3, t3_1]
|
||||
),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t3]),
|
||||
)
|
||||
result = in_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 1.0
|
||||
assert result.overall_eval_status == EvalStatus.PASSED
|
||||
assert result.per_invocation_results[0].score == 1.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED
|
||||
|
||||
|
||||
def test_evaluate_invocations_in_order_match_fails_with_missing_tool_call(
|
||||
in_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with IN_ORDER match type and missing tool call."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t1_1 = genai_types.FunctionCall(name="t1_1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t2_1 = genai_types.FunctionCall(name="t2_1", args={})
|
||||
t3_1 = genai_types.FunctionCall(name="t3_1", args={})
|
||||
t4 = genai_types.FunctionCall(name="t4", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t1_1, t2, t2_1, t3_1]),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t4]),
|
||||
)
|
||||
result = in_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 0.0
|
||||
assert result.overall_eval_status == EvalStatus.FAILED
|
||||
assert result.per_invocation_results[0].score == 0.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED
|
||||
|
||||
|
||||
def test_evaluate_invocations_in_order_match_fails_with_wrong_order(
|
||||
in_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with IN_ORDER match type and wrong order."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t3 = genai_types.FunctionCall(name="t3", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t3, t2]),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t3]),
|
||||
)
|
||||
result = in_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 0.0
|
||||
assert result.overall_eval_status == EvalStatus.FAILED
|
||||
assert result.per_invocation_results[0].score == 0.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def any_order_evaluator() -> TrajectoryEvaluator:
|
||||
"""Returns a TrajectoryEvaluator for ANY_ORDER match."""
|
||||
return TrajectoryEvaluator(
|
||||
eval_metric=EvalMetric(
|
||||
threshold=0.5,
|
||||
metric_name=PrebuiltMetrics.TOOL_TRAJECTORY_AVG_SCORE.value,
|
||||
criterion=ToolTrajectoryCriterion(
|
||||
threshold=0.5,
|
||||
match_type=ToolTrajectoryCriterion.MatchType.ANY_ORDER,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_evaluate_invocations_any_order_match_with_extra_tool_calls_different_order(
|
||||
any_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with ANY_ORDER match type and extra tool calls."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t1_1 = genai_types.FunctionCall(name="t1_1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t2_1 = genai_types.FunctionCall(name="t2_1", args={})
|
||||
t3 = genai_types.FunctionCall(name="t3", args={})
|
||||
t3_1 = genai_types.FunctionCall(name="t3_1", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(
|
||||
tool_uses=[t2, t2_1, t1, t1_1, t3, t3_1]
|
||||
),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t3]),
|
||||
)
|
||||
result = any_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 1.0
|
||||
assert result.overall_eval_status == EvalStatus.PASSED
|
||||
assert result.per_invocation_results[0].score == 1.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED
|
||||
|
||||
|
||||
def test_evaluate_invocations_any_order_match_fails_with_missing_tool_call(
|
||||
any_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with ANY_ORDER match type and missing tool call."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t1_1 = genai_types.FunctionCall(name="t1_1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t2_1 = genai_types.FunctionCall(name="t2_1", args={})
|
||||
t3_1 = genai_types.FunctionCall(name="t3_1", args={})
|
||||
t4 = genai_types.FunctionCall(name="t4", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t1_1, t2, t2_1, t3_1]),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t4]),
|
||||
)
|
||||
result = any_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 0.0
|
||||
assert result.overall_eval_status == EvalStatus.FAILED
|
||||
assert result.per_invocation_results[0].score == 0.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED
|
||||
|
||||
|
||||
def test_evaluate_invocations_any_order_match_with_duplicates(
|
||||
any_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with ANY_ORDER match type with duplicates."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t3 = genai_types.FunctionCall(name="t3", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t3, t1]),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t1]),
|
||||
)
|
||||
result = any_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 1.0
|
||||
assert result.overall_eval_status == EvalStatus.PASSED
|
||||
assert result.per_invocation_results[0].score == 1.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED
|
||||
|
||||
|
||||
def test_evaluate_invocations_any_order_match_fails_with_duplicates_missing(
|
||||
any_order_evaluator: TrajectoryEvaluator,
|
||||
):
|
||||
"""Tests evaluate_invocations with ANY_ORDER match type with missing duplicates."""
|
||||
t1 = genai_types.FunctionCall(name="t1", args={})
|
||||
t2 = genai_types.FunctionCall(name="t2", args={})
|
||||
t3 = genai_types.FunctionCall(name="t3", args={})
|
||||
actual_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t3]),
|
||||
)
|
||||
expected_invocation = Invocation(
|
||||
user_content=_USER_CONTENT,
|
||||
intermediate_data=IntermediateData(tool_uses=[t1, t2, t1]),
|
||||
)
|
||||
result = any_order_evaluator.evaluate_invocations(
|
||||
[actual_invocation], [expected_invocation]
|
||||
)
|
||||
assert result.overall_score == 0.0
|
||||
assert result.overall_eval_status == EvalStatus.FAILED
|
||||
assert result.per_invocation_results[0].score == 0.0
|
||||
assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED
|
||||
|
||||
|
||||
def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
|
||||
"""Tests evaluate_invocations with no invocations."""
|
||||
result = evaluator.evaluate_invocations([], [])
|
||||
|
||||
Reference in New Issue
Block a user