Mirror of https://github.com/encounter/adk-python.git (synced 2026-03-30 10:57:20 -07:00)
chore: Remove deprecated static methods from TrajectoryEvaluator
This change removes the `evaluate`, `_evaluate_row`, `are_tools_equal`, `_remove_tool_outputs`, `_report_failures`, and `_print_results` static methods from `TrajectoryEvaluator`, along with their corresponding unit tests. These methods were previously marked as deprecated.

PiperOrigin-RevId: 817477494
committed by Copybara-Service
parent 81913c85f4
commit 64646e0002
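For callers migrating off the removed static `evaluate` API, the updated tests in this commit exercise the replacement, instance-based `evaluate_invocations` method. Below is a minimal migration sketch (illustrative only, not part of the diff), assuming only the classes and arguments that appear in the tests further down: `TrajectoryEvaluator(threshold=...)`, `Invocation`, `IntermediateData`, and `genai_types.FunctionCall`.

from google.adk.evaluation.eval_case import IntermediateData
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
from google.genai import types as genai_types

# Instead of dict-based eval_dataset rows, build actual/expected Invocation
# objects whose intermediate_data carries the tool-call trajectory.
user_content = genai_types.Content(
    parts=[genai_types.Part(text="Roll a 16 sided dice for me")]
)
roll_die_call = genai_types.FunctionCall(name="roll_die", args={"sides": 16})
actual = Invocation(
    user_content=user_content,
    intermediate_data=IntermediateData(tool_uses=[roll_die_call]),
)
expected = Invocation(
    user_content=user_content,
    intermediate_data=IntermediateData(tool_uses=[roll_die_call]),
)

evaluator = TrajectoryEvaluator(threshold=0.5)
result = evaluator.evaluate_invocations([actual], [expected])
# An exact trajectory match scores 1.0, so the overall status is PASSED here.
print(result.overall_score, result.overall_eval_status)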
@@ -14,13 +14,9 @@

from __future__ import annotations

from typing import Any
from typing import Optional

from google.genai import types as genai_types
import pandas as pd
from tabulate import tabulate
from typing_extensions import deprecated
from typing_extensions import override

from .eval_case import get_all_tool_calls
@@ -30,7 +26,6 @@ from .eval_metrics import Interval
from .eval_metrics import MetricInfo
from .eval_metrics import MetricValueInfo
from .eval_metrics import PrebuiltMetrics
from .evaluation_constants import EvalConstants
from .evaluator import EvalStatus
from .evaluator import EvaluationResult
from .evaluator import Evaluator
@@ -129,170 +124,3 @@ class TrajectoryEvaluator(Evaluator):

  def _get_eval_status(self, score: float):
    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED

  @staticmethod
  @deprecated(
      "This method has been deprecated and will be removed soon. Please use"
      " evaluate_invocations instead."
  )
  def evaluate(
      eval_dataset: list[list[dict[str, Any]]],
      *,
      print_detailed_results: bool = False,
  ):
    r"""Returns the mean tool use accuracy of the eval dataset.

    Tool use accuracy is calculated by comparing the expected and the actual
    tool use trajectories. An exact match scores a 1, 0 otherwise. The final
    number is an average of these individual scores.

    Value range: [0, 1], where 0 means none of the tool use entries aligned,
    and 1 would mean all of them aligned. Higher value is good.

    Args:
      eval_dataset: The dataset that will be evaluated.
      print_detailed_results: Prints detailed results on the console. This is
        usually helpful during debugging.

    A note on eval_dataset:
      The dataset should be a list of sessions, where each session is
      represented as a list of interactions that need evaluation. Each
      evaluation is represented as a dictionary that is expected to have
      values for the following keys:

        1) query
        2) response
        3) actual_tool_use
        4) expected_tool_use

      Here is a sample eval_dataset value with one entry:

      [
        [
          {
            "query": "Roll a 16 sided dice for me",
            "response": "I rolled a 16 sided die and got 13.\n",
            "expected_tool_use": [
              {
                "tool_name": "roll_die",
                "tool_input": {
                  "sides": 16
                }
              }
            ],
            "actual_tool_use": [
              {
                "tool_name": "roll_die",
                "tool_input": {
                  "sides": 16
                }
              }
            ]
          }
        ]
      ]
    """
    if not eval_dataset:
      raise ValueError("The evaluation dataset is empty.")

    results_df = pd.DataFrame(
        columns=[
            "query",
            "response",
            "actual_tool_use",
            "expected_tool_use",
            "tool_use_accuracy",
        ]
    )
    failures = []

    for conversation in eval_dataset:
      for index, row in enumerate(conversation):
        new_row, failure = TrajectoryEvaluator._evaluate_row(row)
        results_df = pd.concat(
            [results_df, pd.DataFrame([new_row])], ignore_index=True
        )
        if failure:
          failure["turn"] = index + 1
          failures.append(failure)

    TrajectoryEvaluator._report_failures(failures)

    if print_detailed_results:
      TrajectoryEvaluator._print_results(results_df)

    return results_df["tool_use_accuracy"].mean()

  @staticmethod
  def _evaluate_row(row):
    # We don't evaluate the mock tool outputs.
    expected = TrajectoryEvaluator._remove_tool_outputs(
        row["expected_tool_use"]
    )
    actual = row["actual_tool_use"]
    tool_use_accuracy = (
        1.0 if TrajectoryEvaluator.are_tools_equal(actual, expected) else 0.0
    )

    new_row = {
        "query": row["query"],
        "response": row["response"],
        "actual_tool_use": actual,
        "expected_tool_use": expected,
        "tool_use_accuracy": tool_use_accuracy,
    }
    failure = (
        None
        if tool_use_accuracy == 1.0
        else {"query": row["query"], "actual": actual, "expected": expected}
    )
    return new_row, failure

  @staticmethod
  @deprecated(
      "are_tools_equal is deprecated and will be removed soon. Please use"
      " TrajectoryEvaluator._are_tool_calls_equal instead."
  )
  def are_tools_equal(list_a_original, list_b_original):
    # Remove other entries that we don't want to evaluate
    list_a = [
        {"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
        for tool in list_a_original
    ]

    list_b = [
        {"tool_name": tool["tool_name"], "tool_input": tool["tool_input"]}
        for tool in list_b_original
    ]

    return list_a == list_b

  @staticmethod
  def _remove_tool_outputs(tool_use_list):
    """Removes 'mock_tool_output' from each dictionary in the list."""
    result = []
    for tool_use in tool_use_list:
      new_tool_use = (
          tool_use.copy()
      )  # Create a copy to avoid modifying the original
      new_tool_use.pop(
          EvalConstants.MOCK_TOOL_OUTPUT, None
      )  # Remove 'tool_output' if it exists
      result.append(new_tool_use)
    return result

  @staticmethod
  def _report_failures(failures):
    if failures:
      print("Failures:")
      for failure in failures:
        print(f"""{{
  "turn": {failure["turn"]},
  "query": '{failure["query"]}',
  "actual": {failure["actual"]},
  "expected_tool_use": {failure["expected"]},
}}
""")

  @staticmethod
  def _print_results(results_df):
    print(tabulate(results_df, headers="keys", tablefmt="grid"))

@@ -14,263 +14,18 @@

"""Tests for the Trajectory Evaluator."""

import math

from google.adk.evaluation.eval_case import IntermediateData
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_metrics import PrebuiltMetrics
from google.adk.evaluation.evaluator import EvalStatus
from google.adk.evaluation.trajectory_evaluator import TrajectoryEvaluator
from google.genai import types as genai_types
import pytest

# Define reusable tool call structures
TOOL_ROLL_DICE_16 = {"tool_name": "roll_die", "tool_input": {"sides": 16}}
TOOL_ROLL_DICE_6 = {"tool_name": "roll_die", "tool_input": {"sides": 6}}
TOOL_GET_WEATHER = {
    "tool_name": "get_weather",
    "tool_input": {"location": "Paris"},
}
TOOL_GET_WEATHER_SF = {
    "tool_name": "get_weather",
    "tool_input": {"location": "SF"},
}

# Sample data for turns
TURN_MATCH = {
    "query": "Q1",
    "response": "R1",
    "actual_tool_use": [TOOL_ROLL_DICE_16],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_INPUT = {
    "query": "Q2",
    "response": "R2",
    "actual_tool_use": [TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MISMATCH_NAME = {
    "query": "Q3",
    "response": "R3",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_ROLL_DICE_16],
}
TURN_MATCH_MULTIPLE = {
    "query": "Q4",
    "response": "R4",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_ORDER = {
    "query": "Q5",
    "response": "R5",
    "actual_tool_use": [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MISMATCH_LENGTH_ACTUAL_LONGER = {
    "query": "Q6",
    "response": "R6",
    "actual_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
    "expected_tool_use": [TOOL_GET_WEATHER],
}
TURN_MISMATCH_LENGTH_EXPECTED_LONGER = {
    "query": "Q7",
    "response": "R7",
    "actual_tool_use": [TOOL_GET_WEATHER],
    "expected_tool_use": [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6],
}
TURN_MATCH_WITH_MOCK_OUTPUT = {
    "query": "Q8",
    "response": "R8",
    "actual_tool_use": [TOOL_GET_WEATHER_SF],
    "expected_tool_use": [
        {**TOOL_GET_WEATHER_SF, "mock_tool_output": "Sunny"}
    ],  # Add mock output to expected
}
TURN_MATCH_EMPTY_TOOLS = {
    "query": "Q9",
    "response": "R9",
    "actual_tool_use": [],
    "expected_tool_use": [],
}
TURN_MISMATCH_EMPTY_VS_NONEMPTY = {
    "query": "Q10",
    "response": "R10",
    "actual_tool_use": [],
    "expected_tool_use": [TOOL_GET_WEATHER],
}

def test_evaluate_none_dataset_raises_value_error():
  """Tests evaluate function raises ValueError for a None dataset."""
  with pytest.raises(ValueError, match="The evaluation dataset is empty."):
    TrajectoryEvaluator.evaluate(None)

def test_evaluate_empty_dataset_raises_value_error():
  """Tests evaluate function raises ValueError for an empty list."""
  with pytest.raises(ValueError, match="The evaluation dataset is empty."):
    TrajectoryEvaluator.evaluate([])


def test_evaluate_single_turn_match():
  """Tests evaluate function with one conversation, one turn, perfect match."""
  eval_dataset = [[TURN_MATCH]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_single_turn_mismatch():
  """Tests evaluate function with one conversation, one turn, mismatch."""
  eval_dataset = [[TURN_MISMATCH_INPUT]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0


def test_evaluate_multiple_turns_all_match():
  """Tests evaluate function with one conversation, multiple turns, all match."""
  eval_dataset = [[TURN_MATCH, TURN_MATCH_MULTIPLE, TURN_MATCH_EMPTY_TOOLS]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_multiple_turns_mixed():
  """Tests evaluate function with one conversation, mixed match/mismatch turns."""
  eval_dataset = [
      [TURN_MATCH, TURN_MISMATCH_NAME, TURN_MATCH_MULTIPLE, TURN_MISMATCH_ORDER]
  ]
  # Expected: (1.0 + 0.0 + 1.0 + 0.0) / 4 = 0.5
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5


def test_evaluate_multiple_conversations_mixed():
  """Tests evaluate function with multiple conversations, mixed turns."""
  eval_dataset = [
      [TURN_MATCH, TURN_MISMATCH_INPUT],  # Conv 1: 1.0, 0.0 -> Avg 0.5
      [TURN_MATCH_MULTIPLE],  # Conv 2: 1.0 -> Avg 1.0
      [
          TURN_MISMATCH_ORDER,
          TURN_MISMATCH_LENGTH_ACTUAL_LONGER,
          TURN_MATCH,
      ],  # Conv 3: 0.0, 0.0, 1.0 -> Avg 1/3
  ]
  # Expected: (1.0 + 0.0 + 1.0 + 0.0 + 0.0 + 1.0) / 6 = 3.0 / 6 = 0.5
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.5


def test_evaluate_ignores_mock_tool_output_in_expected():
  """Tests evaluate function correctly compares even if expected has mock_tool_output."""
  eval_dataset = [[TURN_MATCH_WITH_MOCK_OUTPUT]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_match_empty_tool_lists():
  """Tests evaluate function correctly matches empty tool lists."""
  eval_dataset = [[TURN_MATCH_EMPTY_TOOLS]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_mismatch_empty_vs_nonempty():
  """Tests evaluate function correctly mismatches empty vs non-empty tool lists."""
  eval_dataset = [[TURN_MISMATCH_EMPTY_VS_NONEMPTY]]
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 0.0
  eval_dataset_rev = [[{
      **TURN_MISMATCH_EMPTY_VS_NONEMPTY,  # Swap actual/expected
      "actual_tool_use": [TOOL_GET_WEATHER],
      "expected_tool_use": [],
  }]]
  assert TrajectoryEvaluator.evaluate(eval_dataset_rev) == 0.0


def test_evaluate_dataset_with_empty_conversation():
  """Tests evaluate function handles dataset containing an empty conversation list."""
  eval_dataset = [[TURN_MATCH], []]  # One valid conversation, one empty
  # Should only evaluate the first conversation -> 1.0 / 1 turn = 1.0
  assert TrajectoryEvaluator.evaluate(eval_dataset) == 1.0


def test_evaluate_dataset_only_empty_conversation():
  """Tests evaluate function handles dataset with only an empty conversation."""
  eval_dataset = [[]]
  # No rows evaluated, mean of empty series is NaN
  # Depending on desired behavior, this could be 0.0 or NaN. The code returns
  # NaN.
  assert math.isnan(TrajectoryEvaluator.evaluate(eval_dataset))


def test_evaluate_print_detailed_results(capsys):
  """Tests evaluate function runs with print_detailed_results=True and prints something."""
  eval_dataset = [[TURN_MATCH, TURN_MISMATCH_INPUT]]
  TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
  captured = capsys.readouterr()
  assert "query" in captured.out  # Check if the results table header is printed
  assert "R1" in captured.out  # Check if some data is printed
  assert "Failures:" in captured.out  # Check if failures header is printed
  assert "Q2" in captured.out  # Check if the failing query is printed


def test_evaluate_no_failures_print(capsys):
  """Tests evaluate function does not print Failures section when all turns match."""
  eval_dataset = [[TURN_MATCH]]
  TrajectoryEvaluator.evaluate(eval_dataset, print_detailed_results=True)
  captured = capsys.readouterr()
  assert "query" in captured.out  # Results table should still print
  assert "Failures:" not in captured.out  # Failures section should NOT print

def test_are_tools_equal_identical():
  """Tests are_tools_equal function with identical lists."""
  list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_empty():
  """Tests are_tools_equal function with empty lists."""
  assert TrajectoryEvaluator.are_tools_equal([], [])


def test_are_tools_equal_different_order():
  """Tests are_tools_equal function with same tools, different order."""
  list_a = [TOOL_ROLL_DICE_6, TOOL_GET_WEATHER]
  list_b = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_length():
  """Tests are_tools_equal function with lists of different lengths."""
  list_a = [TOOL_GET_WEATHER, TOOL_ROLL_DICE_6]
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_input_values():
  """Tests are_tools_equal function with different input values."""
  list_a = [TOOL_ROLL_DICE_16]
  list_b = [TOOL_ROLL_DICE_6]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_different_tool_names():
  """Tests are_tools_equal function with different tool names."""
  list_a = [TOOL_ROLL_DICE_16]
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_ignores_extra_keys():
  """Tests are_tools_equal function ignores keys other than tool_name/tool_input."""
  list_a = [{
      "tool_name": "get_weather",
      "tool_input": {"location": "Paris"},
      "extra_key": "abc",
  }]
  list_b = [{
      "tool_name": "get_weather",
      "tool_input": {"location": "Paris"},
      "other_key": 123,
  }]
  assert TrajectoryEvaluator.are_tools_equal(list_a, list_b)


def test_are_tools_equal_one_empty_one_not():
  """Tests are_tools_equal function with one empty list and one non-empty list."""
  list_a = []
  list_b = [TOOL_GET_WEATHER]
  assert not TrajectoryEvaluator.are_tools_equal(list_a, list_b)
_USER_CONTENT = genai_types.Content(
    parts=[genai_types.Part(text="User input here.")]
)


def test_get_metric_info():
@@ -281,3 +36,149 @@ def test_get_metric_info():
  )
  assert metric_info.metric_value_info.interval.min_value == 0.0
  assert metric_info.metric_value_info.interval.max_value == 1.0


@pytest.fixture
def evaluator() -> TrajectoryEvaluator:
  """Returns a TrajectoryEvaluator."""
  return TrajectoryEvaluator(threshold=0.5)


def test_evaluate_invocations_equal_tool_calls(evaluator: TrajectoryEvaluator):
  """Tests evaluate_invocations with equal tool calls."""
  tool_call = genai_types.FunctionCall(name="test_func", args={"arg1": "val1"})
  intermediate_data = IntermediateData(tool_uses=[tool_call])
  invocation = Invocation(
      user_content=_USER_CONTENT, intermediate_data=intermediate_data
  )
  result = evaluator.evaluate_invocations([invocation], [invocation])
  assert result.overall_score == 1.0
  assert result.overall_eval_status == EvalStatus.PASSED
  assert len(result.per_invocation_results) == 1
  assert result.per_invocation_results[0].score == 1.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED


def test_evaluate_invocations_different_tool_call_names(
    evaluator: TrajectoryEvaluator,
):
  """Tests evaluate_invocations with different tool call names."""
  tool_call1 = genai_types.FunctionCall(
      name="test_func1", args={"arg1": "val1"}
  )
  tool_call2 = genai_types.FunctionCall(
      name="test_func2", args={"arg1": "val1"}
  )
  invocation1 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  invocation2 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call2]),
  )
  result = evaluator.evaluate_invocations([invocation1], [invocation2])
  assert result.overall_score == 0.0
  assert result.overall_eval_status == EvalStatus.FAILED
  assert result.per_invocation_results[0].score == 0.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED


def test_evaluate_invocations_different_tool_call_args(
    evaluator: TrajectoryEvaluator,
):
  """Tests evaluate_invocations with different tool call args."""
  tool_call1 = genai_types.FunctionCall(name="test_func", args={"arg1": "val1"})
  tool_call2 = genai_types.FunctionCall(name="test_func", args={"arg1": "val2"})
  invocation1 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  invocation2 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call2]),
  )
  result = evaluator.evaluate_invocations([invocation1], [invocation2])
  assert result.overall_score == 0.0
  assert result.overall_eval_status == EvalStatus.FAILED
  assert result.per_invocation_results[0].score == 0.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED


def test_evaluate_invocations_different_number_of_tool_calls(
    evaluator: TrajectoryEvaluator,
):
  """Tests evaluate_invocations with different number of tool calls."""
  tool_call1 = genai_types.FunctionCall(name="test_func", args={"arg1": "val1"})
  tool_call2 = genai_types.FunctionCall(name="test_func", args={"arg1": "val1"})
  invocation1 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  invocation2 = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1, tool_call2]),
  )
  result = evaluator.evaluate_invocations([invocation1], [invocation2])
  assert result.overall_score == 0.0
  assert result.overall_eval_status == EvalStatus.FAILED
  assert result.per_invocation_results[0].score == 0.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.FAILED


def test_evaluate_invocations_no_tool_calls(evaluator: TrajectoryEvaluator):
  """Tests evaluate_invocations with no tool calls."""
  invocation = Invocation(
      user_content=_USER_CONTENT, intermediate_data=IntermediateData()
  )
  result = evaluator.evaluate_invocations([invocation], [invocation])
  assert result.overall_score == 1.0
  assert result.overall_eval_status == EvalStatus.PASSED
  assert result.per_invocation_results[0].score == 1.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED


def test_evaluate_invocations_multiple_invocations(
    evaluator: TrajectoryEvaluator,
):
  """Tests evaluate_invocations with multiple invocations."""
  tool_call1 = genai_types.FunctionCall(
      name="test_func1", args={"arg1": "val1"}
  )
  tool_call2 = genai_types.FunctionCall(
      name="test_func2", args={"arg1": "val1"}
  )
  inv1_actual = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  inv1_expected = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  inv2_actual = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call1]),
  )
  inv2_expected = Invocation(
      user_content=_USER_CONTENT,
      intermediate_data=IntermediateData(tool_uses=[tool_call2]),
  )
  result = evaluator.evaluate_invocations(
      [inv1_actual, inv2_actual], [inv1_expected, inv2_expected]
  )
  assert result.overall_score == 0.5
  assert result.overall_eval_status == EvalStatus.PASSED
  assert len(result.per_invocation_results) == 2
  assert result.per_invocation_results[0].score == 1.0
  assert result.per_invocation_results[0].eval_status == EvalStatus.PASSED
  assert result.per_invocation_results[1].score == 0.0
  assert result.per_invocation_results[1].eval_status == EvalStatus.FAILED


def test_evaluate_invocations_no_invocations(evaluator: TrajectoryEvaluator):
  """Tests evaluate_invocations with no invocations."""
  result = evaluator.evaluate_invocations([], [])
  assert result.overall_score is None
  assert result.overall_eval_status == EvalStatus.NOT_EVALUATED
  assert not result.per_invocation_results