chore: Remove deprecated run_evals from cli_eval.py
This change removes the `run_evals` function and its helper `_get_evaluator` from `cli_eval.py`, as they were marked as deprecated. Corresponding test mocks and patches in `test_fast_api.py` are also removed.

PiperOrigin-RevId: 818719422
Committed by: Copybara-Service
Parent: e212ff558e
Commit: 348e552ba6
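For context, the deprecation message pointed callers at LocalEvalService as the replacement. The sketch below is a hypothetical migration example, not part of this commit; it assumes `LocalEvalService` is importable from `google.adk.evaluation.local_eval_service`, that its constructor accepts `root_agent`, and that its `perform_inference`/`evaluate` methods consume the `InferenceRequest`/`EvaluateRequest` types that `cli_eval.py` imports in the diff below. Verify the exact field names against your ADK version.

# Hypothetical migration sketch -- replaces the removed run_evals().
# Assumed: the LocalEvalService module path, constructor arguments, and
# request field names; check them against your installed ADK version.
from google.adk.evaluation.base_eval_service import EvaluateConfig
from google.adk.evaluation.base_eval_service import EvaluateRequest
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService  # assumed path


async def run_custom_evals(root_agent, eval_set_id, eval_case_ids, eval_metrics):
  """Yields one result per eval case, mirroring the old run_evals stream."""
  eval_service = LocalEvalService(root_agent=root_agent)  # assumed constructor

  # Phase 1: run inference for the selected eval cases.
  inference_request = InferenceRequest(
      app_name="my_app",  # assumed field
      eval_set_id=eval_set_id,
      eval_case_ids=eval_case_ids,
      inference_config=InferenceConfig(),
  )
  inference_results = []
  async for inference_result in eval_service.perform_inference(
      inference_request=inference_request
  ):
    inference_results.append(inference_result)

  # Phase 2: score the collected inferences against the configured metrics.
  evaluate_request = EvaluateRequest(
      inference_results=inference_results,
      evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
  )
  async for eval_case_result in eval_service.evaluate(
      evaluate_request=evaluate_request
  ):
    yield eval_case_result

Unlike the removed `run_evals`, which interleaved inference and scoring per case, the service API splits the two phases, so inference results can be cached or re-scored without re-running the agent.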
@@ -72,6 +72,7 @@ from ..evaluation.eval_case import SessionInput
from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_metrics import EvalStatus
from ..evaluation.eval_metrics import MetricInfo
from ..evaluation.eval_result import EvalSetResult
from ..evaluation.eval_set import EvalSet
@@ -85,7 +86,6 @@ from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalStatus
from .utils import cleanup
from .utils import common
from .utils import envs
@@ -15,42 +15,27 @@
from __future__ import annotations

import importlib.util
import inspect
import logging
import os
import sys
from typing import Any
from typing import AsyncGenerator
from typing import Optional
import uuid

import click
from google.genai import types as genai_types
from typing_extensions import deprecated

from ..agents.llm_agent import Agent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..evaluation.base_eval_service import BaseEvalService
from ..evaluation.base_eval_service import EvaluateConfig
from ..evaluation.base_eval_service import EvaluateRequest
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.base_eval_service import InferenceResult
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from ..evaluation.eval_case import EvalCase
from ..evaluation.eval_case import get_all_tool_calls
from ..evaluation.eval_case import IntermediateDataType
from ..evaluation.eval_config import BaseCriterion
from ..evaluation.eval_config import EvalConfig
from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_metrics import JudgeModelOptions
from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.eval_sets_manager import EvalSetsManager
from ..evaluation.evaluator import EvalStatus
from ..evaluation.evaluator import Evaluator
from ..sessions.base_session_service import BaseSessionService
from ..utils.context_utils import Aclosing

logger = logging.getLogger("google_adk." + __name__)
@@ -172,147 +157,6 @@ async def _collect_eval_results(
  return eval_results


@deprecated(
    "This method is deprecated and will be removed in a future release."
    " Please use LocalEvalService to define your custom evals."
)
async def run_evals(
    eval_cases_by_eval_set_id: dict[str, list[EvalCase]],
    root_agent: Agent,
    reset_func: Optional[Any],
    eval_metrics: list[EvalMetric],
    session_service: Optional[BaseSessionService] = None,
    artifact_service: Optional[BaseArtifactService] = None,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns a stream of EvalCaseResult for each eval case that was evaluated.
|
||||
|
||||
Args:
|
||||
eval_cases_by_eval_set_id: Eval cases categorized by eval set id to which
|
||||
they belong.
|
||||
root_agent: Agent to use for inferencing.
|
||||
reset_func: If present, this will be called before invoking the agent before
|
||||
every inferencing step.
|
||||
eval_metrics: A list of metrics that should be used during evaluation.
|
||||
session_service: The session service to use during inferencing.
|
||||
artifact_service: The artifact service to use during inferencing.
|
||||
"""
|
||||
  try:
    from ..evaluation.evaluation_generator import EvaluationGenerator
  except ModuleNotFoundError as e:
    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e

  for eval_set_id, eval_cases in eval_cases_by_eval_set_id.items():
    for eval_case in eval_cases:
      eval_name = eval_case.eval_id
      initial_session = eval_case.session_input
      user_id = initial_session.user_id if initial_session else "test_user_id"

      try:
        print(f"Running Eval: {eval_set_id}:{eval_name}")
        session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

        inference_result = (
            await EvaluationGenerator._generate_inferences_from_root_agent(
                invocations=eval_case.conversation,
                root_agent=root_agent,
                reset_func=reset_func,
                initial_session=initial_session,
                session_id=session_id,
                session_service=session_service,
                artifact_service=artifact_service,
            )
        )

        # Initialize the per-invocation metric results to an empty list.
        # We will fill this as we evaluate each metric.
        eval_metric_result_per_invocation = []
        for actual, expected in zip(inference_result, eval_case.conversation):
          eval_metric_result_per_invocation.append(
              EvalMetricResultPerInvocation(
                  actual_invocation=actual,
                  expected_invocation=expected,
                  eval_metric_results=[],
              )
          )

        overall_eval_metric_results = []

        for eval_metric in eval_metrics:
          metric_evaluator = _get_evaluator(eval_metric)

          if inspect.iscoroutinefunction(metric_evaluator.evaluate_invocations):
            evaluation_result = await metric_evaluator.evaluate_invocations(
                actual_invocations=inference_result,
                expected_invocations=eval_case.conversation,
            )
          else:
            evaluation_result = metric_evaluator.evaluate_invocations(
                actual_invocations=inference_result,
                expected_invocations=eval_case.conversation,
            )

          overall_eval_metric_results.append(
              EvalMetricResult(
                  metric_name=eval_metric.metric_name,
                  threshold=eval_metric.threshold,
                  score=evaluation_result.overall_score,
                  eval_status=evaluation_result.overall_eval_status,
              )
          )
          for index, per_invocation_result in enumerate(
              evaluation_result.per_invocation_results
          ):
            eval_metric_result_per_invocation[index].eval_metric_results.append(
                EvalMetricResult(
                    metric_name=eval_metric.metric_name,
                    threshold=eval_metric.threshold,
                    score=per_invocation_result.score,
                    eval_status=per_invocation_result.eval_status,
                )
            )

        final_eval_status = EvalStatus.NOT_EVALUATED
        # Go over all the eval statuses and mark the final eval status as
        # passed if all of them pass; otherwise mark the final eval status as
        # failed.
        for overall_eval_metric_result in overall_eval_metric_results:
          overall_eval_status = overall_eval_metric_result.eval_status
          if overall_eval_status == EvalStatus.PASSED:
            final_eval_status = EvalStatus.PASSED
          elif overall_eval_status == EvalStatus.NOT_EVALUATED:
            continue
          elif overall_eval_status == EvalStatus.FAILED:
            final_eval_status = EvalStatus.FAILED
            break
          else:
            raise ValueError("Unknown eval status.")

        yield EvalCaseResult(
            eval_set_file=eval_set_id,
            eval_set_id=eval_set_id,
            eval_id=eval_name,
            final_eval_status=final_eval_status,
            eval_metric_results=[],
            overall_eval_metric_results=overall_eval_metric_results,
            eval_metric_result_per_invocation=eval_metric_result_per_invocation,
            session_id=session_id,
            user_id=user_id,
        )

        if final_eval_status == EvalStatus.PASSED:
          result = "✅ Passed"
        else:
          result = "❌ Failed"

        print(f"Result: {result}\n")
      except ModuleNotFoundError as e:
        raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
      except Exception:
        # Catching the general exception, so that we don't block other eval
        # cases.
        logger.exception("Eval failed for `%s:%s`", eval_set_id, eval_name)


def _convert_content_to_text(
    content: Optional[genai_types.Content],
) -> str:
@@ -413,32 +257,6 @@ def pretty_print_eval_result(eval_result: EvalCaseResult):
  click.echo("\n\n")  # Few empty lines for visual clarity


def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
  try:
    from ..evaluation.final_response_match_v2 import FinalResponseMatchV2Evaluator
    from ..evaluation.response_evaluator import ResponseEvaluator
    from ..evaluation.safety_evaluator import SafetyEvaluatorV1
    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
  except ModuleNotFoundError as e:
    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
    return TrajectoryEvaluator(threshold=eval_metric.threshold)
  elif (
      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
  ):
    return ResponseEvaluator(
        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
    )
  elif eval_metric.metric_name == SAFETY_V1_KEY:
    return SafetyEvaluatorV1(eval_metric)
  elif eval_metric.metric_name == FINAL_RESPONSE_MATCH_V2:
    eval_metric.judge_model_options = JudgeModelOptions()
    return FinalResponseMatchV2Evaluator(eval_metric)

  raise ValueError(f"Unsupported eval metric: {eval_metric}")


def get_eval_sets_manager(
    eval_storage_uri: Optional[str], agents_dir: str
) -> EvalSetsManager:
@@ -153,28 +153,6 @@ class _MockEvalCaseResult(BaseModel):
  eval_metric_result_per_invocation: list = {}


# Mock for the run_evals function, tailored for test_run_eval
async def mock_run_evals_for_fast_api(*args, **kwargs):
  # This is what the test_run_eval expects for its assertions
  yield _MockEvalCaseResult(
      eval_set_id="test_eval_set_id",  # Matches expected in verify_eval_case_result
      eval_id="test_eval_case_id",  # Matches expected
      final_eval_status=1,  # Matches expected (assuming 1 is PASSED)
      user_id="test_user",  # Placeholder, adapt if needed
      session_id="test_session_for_eval_case",  # Placeholder
      eval_set_file="test_eval_set_file",  # Placeholder
      overall_eval_metric_results=[{  # Matches expected
          "metricName": "tool_trajectory_avg_score",
          "threshold": 0.5,
          "score": 1.0,
          "evalStatus": 1,
      }],
      # Provide other fields if RunEvalResult or subsequent processing needs them
      eval_metric_results=[],
      eval_metric_result_per_invocation=[],
  )


#################################################
# Test Fixtures
#################################################
@@ -453,10 +431,6 @@ def test_app(
          "google.adk.cli.fast_api.LocalEvalSetResultsManager",
          return_value=mock_eval_set_results_manager,
      ),
      patch(
          "google.adk.cli.cli_eval.run_evals",  # Patch where it's imported in fast_api.py
          new=mock_run_evals_for_fast_api,
      ),
  ):
    # Get the FastAPI app, but don't actually run it
    app = get_fast_api_app(
@@ -604,10 +578,6 @@ def test_app_with_a2a(
          "google.adk.cli.fast_api.LocalEvalSetResultsManager",
          return_value=mock_eval_set_results_manager,
      ),
      patch(
          "google.adk.cli.cli_eval.run_evals",
          new=mock_run_evals_for_fast_api,
      ),
      patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store,
      patch(
          "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"