feat: Add implementation of BaseEvalService that runs evals locally

This change:
- Introduces the LocalEvalService class.
- Implements only the `perform_inference` method; the `evaluate` method will be implemented in the next CL.
- Adds required test coverage.

PiperOrigin-RevId: 781781954
Author: Ankur Sharma
Date: 2025-07-10 19:24:49 -07:00
Committed by: Copybara-Service
Parent: 162228d208
Commit: 51be7a899c
5 changed files with 398 additions and 3 deletions
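The concurrency model introduced here is a standard semaphore-bounded fan-out over asyncio coroutines. In isolation the pattern looks roughly like this (a generic sketch, not ADK API; `items` and `worker` are placeholders):

import asyncio


async def bounded_fan_out(items, worker, parallelism=4):
  """Runs `worker` over `items`, at most `parallelism` at a time,
  yielding results in completion order."""
  semaphore = asyncio.Semaphore(parallelism)

  async def run(item):
    async with semaphore:
      return await worker(item)

  for next_done in asyncio.as_completed([run(item) for item in items]):
    yield await next_done

Results arrive in completion order rather than submission order, which is why `perform_inference` below reports results "as and when they become available".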
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing_extensions import override
from ..agents import BaseAgent
class IdentityAgentCreator:
"""An implementation of the AgentCreator interface that always returns a copy of the root agent."""
def __init__(self, root_agent: BaseAgent):
self._root_agent = root_agent
@override
def get_agent(
self,
) -> BaseAgent:
"""Returns a deep copy of the root agent."""
return self._root_agent.clone()
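A minimal usage sketch for the creator (the agent construction is illustrative; any concrete BaseAgent works):

from google.adk.agents.llm_agent import LlmAgent

root = LlmAgent(name='my_agent', model='gemini-2.0-flash')
creator = IdentityAgentCreator(root_agent=root)

# Each call returns an independent copy, so concurrent eval runs
# cannot mutate each other's agent state.
agent_copy = creator.get_agent()
assert agent_copy is not root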
@@ -16,6 +16,7 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import AsyncGenerator
from typing import Optional
@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
charges.""",
)
parallelism: int = Field(
default=4,
description="""Number of parallel inferences to run during an Eval. Few
factors to consider while changing this value:
1) Your available quota with the model. Models tend to enforce per-minute or
per-second SLAs. Using a larger value could result in the eval quickly consuming
the quota.
2) The tools used by the Agent could also have their SLA. Using a larger value
could also overwhelm those tools.""",
)
class InferenceRequest(BaseModel):
"""Represent a request to perform inferences for the eval cases in an eval set."""
@@ -88,6 +102,14 @@ in an eval set are evaluated.
)
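As a concrete sketch of how the request models above compose (values are illustrative):

config = InferenceConfig(parallelism=2)  # cap concurrent inferences at 2
request = InferenceRequest(
    app_name='my_app',
    eval_set_id='my_eval_set',
    eval_case_ids=['case1'],  # optional; omit to run every case in the set
    inference_config=config,
)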
class InferenceStatus(Enum):
"""Status of the inference."""
UNKNOWN = 0
SUCCESS = 1
FAILURE = 2
class InferenceResult(BaseModel):
"""Contains inference results for a single eval case."""
@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
description="""Id of the eval case for which inferences were generated.""",
)
inferences: Optional[list[Invocation]] = Field(
    default=None,
    description="""Inferences obtained from the Agent for the eval case.""",
)
session_id: Optional[str] = Field(
description="""Id of the inference session."""
)
status: InferenceStatus = Field(
default=InferenceStatus.UNKNOWN,
description="""Status of the inference.""",
)
error_message: Optional[str] = Field(
default=None,
description="""Error message if the inference failed.""",
)
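A consumer of these results might branch on the status roughly like this (a sketch; `eval_service` is the LocalEvalService introduced below, and `request` is an InferenceRequest as above):

async for result in eval_service.perform_inference(request):
  if result.status == InferenceStatus.SUCCESS:
    print(f'{result.eval_case_id}: {len(result.inferences)} invocations')
  else:
    print(f'{result.eval_case_id} failed: {result.error_message}')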
class EvaluateRequest(BaseModel):
model_config = ConfigDict(
@@ -137,7 +137,7 @@ class EvaluationGenerator:
async def _generate_inferences_from_root_agent(
invocations: list[Invocation],
root_agent: Agent,
reset_func: Optional[Any] = None,
initial_session: Optional[SessionInput] = None,
session_id: Optional[str] = None,
session_service: Optional[BaseSessionService] = None,
@@ -0,0 +1,183 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid
from typing_extensions import override
from ..agents import BaseAgent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.feature_decorator import working_in_progress
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry
logger = logging.getLogger('google_adk.' + __name__)
EVAL_SESSION_ID_PREFIX = '___eval___session___'
def _get_session_id() -> str:
return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'
@working_in_progress("Incomplete feature, don't use yet")
class LocalEvalService(BaseEvalService):
"""An implementation of BaseEvalService, that runs the evals locally."""
def __init__(
self,
root_agent: BaseAgent,
eval_sets_manager: EvalSetsManager,
metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
session_service: BaseSessionService = InMemorySessionService(),
artifact_service: BaseArtifactService = InMemoryArtifactService(),
eval_set_results_manager: Optional[EvalSetResultsManager] = None,
session_id_supplier: Callable[[], str] = _get_session_id,
):
self._root_agent = root_agent
self._eval_sets_manager = eval_sets_manager
self._metric_evaluator_registry = metric_evaluator_registry
self._session_service = session_service
self._artifact_service = artifact_service
self._eval_set_results_manager = eval_set_results_manager
self._session_id_supplier = session_id_supplier
@override
async def perform_inference(
self,
inference_request: InferenceRequest,
) -> AsyncGenerator[InferenceResult, None]:
"""Returns InferenceResult obtained from the Agent as and when they are available.
Args:
inference_request: The request for generating inferences.
"""
# Get the eval set from the storage.
eval_set = self._eval_sets_manager.get_eval_set(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
)
if not eval_set:
raise NotFoundError(
f'Eval set with id {inference_request.eval_set_id} not found for app'
f' {inference_request.app_name}'
)
# Select eval cases for which we need to run inference. If the inference
# request specified eval cases, then we use only those.
eval_cases = eval_set.eval_cases
if inference_request.eval_case_ids:
eval_cases = [
eval_case
for eval_case in eval_cases
if eval_case.eval_id in inference_request.eval_case_ids
]
root_agent = self._root_agent.clone()
semaphore = asyncio.Semaphore(
value=inference_request.inference_config.parallelism
)
async def run_inference(eval_case):
async with semaphore:
return await self._perform_inference_single_eval_item(
app_name=inference_request.app_name,
eval_set_id=inference_request.eval_set_id,
eval_case=eval_case,
root_agent=root_agent,
)
inference_results = [run_inference(eval_case) for eval_case in eval_cases]
for inference_result in asyncio.as_completed(inference_results):
yield await inference_result
@override
async def evaluate(
self,
evaluate_request: EvaluateRequest,
) -> AsyncGenerator[EvalCaseResult, None]:
"""Returns EvalCaseResult for each item as and when they are available.
Args:
evaluate_request: The request to perform metric evaluations on the
inferences.
"""
raise NotImplementedError()
async def _perform_inference_single_eval_item(
self,
app_name: str,
eval_set_id: str,
eval_case: EvalCase,
root_agent: BaseAgent,
) -> InferenceResult:
initial_session = eval_case.session_input
session_id = self._session_id_supplier()
inference_result = InferenceResult(
app_name=app_name,
eval_set_id=eval_set_id,
eval_case_id=eval_case.eval_id,
session_id=session_id,
)
try:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
)
)
inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS
return inference_result
except Exception as e:
# We intentionally catch the Exception, as we don't want failures in one
# eval case to affect the other inferences.
logger.error(
'Inference failed for eval case `%s` with error %s',
eval_case.eval_id,
e,
)
inference_result.status = InferenceStatus.FAILURE
inference_result.error_message = str(e)
return inference_result
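Putting it together, intended end-to-end usage looks roughly like this (a sketch; `my_eval_sets_manager` is a hypothetical placeholder for any EvalSetsManager implementation, and note the class is still gated behind @working_in_progress in this CL):

import asyncio

from google.adk.agents.llm_agent import LlmAgent
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService


async def main():
  service = LocalEvalService(
      root_agent=LlmAgent(name='my_agent', model='gemini-2.0-flash'),
      eval_sets_manager=my_eval_sets_manager,  # hypothetical placeholder
  )
  request = InferenceRequest(
      app_name='my_app',
      eval_set_id='my_eval_set',
      inference_config=InferenceConfig(parallelism=4),
  )
  async for result in service.perform_inference(request):
    print(result.eval_case_id, result.status)


asyncio.run(main())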
@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.agents.llm_agent import LlmAgent
from google.adk.errors.not_found_error import NotFoundError
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.eval_set import EvalCase
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.eval_sets_manager import EvalSetsManager
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.models.registry import LLMRegistry
import pytest
@pytest.fixture
def mock_eval_sets_manager():
return mock.create_autospec(EvalSetsManager)
@pytest.fixture
def dummy_agent():
llm = LLMRegistry.new_llm("gemini-pro")
return LlmAgent(name="test_agent", model=llm)
@pytest.fixture
def eval_service(dummy_agent, mock_eval_sets_manager):
return LocalEvalService(
root_agent=dummy_agent,
eval_sets_manager=mock_eval_sets_manager,
)
@pytest.mark.asyncio
async def test_perform_inference_success(
eval_service, dummy_agent, mock_eval_sets_manager
):
eval_set = EvalSet(
eval_set_id="test_eval_set",
eval_cases=[
EvalCase(eval_id="case1", conversation=[], session_input=None),
EvalCase(eval_id="case2", conversation=[], session_input=None),
],
)
mock_eval_sets_manager.get_eval_set.return_value = eval_set
mock_inference_result = mock.MagicMock()
eval_service._perform_inference_single_eval_item = mock.AsyncMock(
return_value=mock_inference_result
)
inference_request = InferenceRequest(
app_name="test_app",
eval_set_id="test_eval_set",
inference_config=InferenceConfig(parallelism=2),
)
results = []
async for result in eval_service.perform_inference(inference_request):
results.append(result)
assert len(results) == 2
assert results[0] == mock_inference_result
assert results[1] == mock_inference_result
mock_eval_sets_manager.get_eval_set.assert_called_once_with(
app_name="test_app", eval_set_id="test_eval_set"
)
assert eval_service._perform_inference_single_eval_item.call_count == 2
@pytest.mark.asyncio
async def test_perform_inference_with_case_ids(
eval_service, dummy_agent, mock_eval_sets_manager
):
eval_set = EvalSet(
eval_set_id="test_eval_set",
eval_cases=[
EvalCase(eval_id="case1", conversation=[], session_input=None),
EvalCase(eval_id="case2", conversation=[], session_input=None),
EvalCase(eval_id="case3", conversation=[], session_input=None),
],
)
mock_eval_sets_manager.get_eval_set.return_value = eval_set
mock_inference_result = mock.MagicMock()
eval_service._perform_inference_single_eval_item = mock.AsyncMock(
return_value=mock_inference_result
)
inference_request = InferenceRequest(
app_name="test_app",
eval_set_id="test_eval_set",
eval_case_ids=["case1", "case3"],
inference_config=InferenceConfig(parallelism=1),
)
results = []
async for result in eval_service.perform_inference(inference_request):
results.append(result)
assert len(results) == 2
eval_service._perform_inference_single_eval_item.assert_any_call(
app_name="test_app",
eval_set_id="test_eval_set",
eval_case=eval_set.eval_cases[0],
root_agent=dummy_agent,
)
eval_service._perform_inference_single_eval_item.assert_any_call(
app_name="test_app",
eval_set_id="test_eval_set",
eval_case=eval_set.eval_cases[2],
root_agent=dummy_agent,
)
@pytest.mark.asyncio
async def test_perform_inference_eval_set_not_found(
eval_service, mock_eval_sets_manager
):
mock_eval_sets_manager.get_eval_set.return_value = None
inference_request = InferenceRequest(
app_name="test_app",
eval_set_id="not_found_set",
inference_config=InferenceConfig(parallelism=1),
)
with pytest.raises(NotFoundError):
async for _ in eval_service.perform_inference(inference_request):
pass