mirror of https://github.com/encounter/adk-python.git
feat: Add implementation of BaseEvalService that runs evals locally
This change:
- Introduces the LocalEvalService class.
- Implements only the "perform_inference" method. The "evaluate" method will be implemented in the next CL.
- Adds required test coverage.

PiperOrigin-RevId: 781781954
committed by Copybara-Service
parent 162228d208
commit 51be7a899c
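For orientation, a minimal sketch of how the new service is meant to be driven, based only on the code added in this commit. The stubbed manager mirrors the new tests; the agent name, app name, and eval set id are placeholders, and the class is still flagged @working_in_progress, so treat this as intent rather than a supported entry point.

import asyncio
from unittest import mock

from google.adk.agents.llm_agent import LlmAgent
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.eval_sets_manager import EvalSetsManager
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.models.registry import LLMRegistry


async def main():
  # Stubbed manager, as in the new tests; a real EvalSetsManager would load
  # a persisted eval set instead.
  eval_sets_manager = mock.create_autospec(EvalSetsManager)
  eval_sets_manager.get_eval_set.return_value = EvalSet(
      eval_set_id='my_eval_set', eval_cases=[]
  )

  service = LocalEvalService(
      root_agent=LlmAgent(name='my_agent', model=LLMRegistry.new_llm('gemini-pro')),
      eval_sets_manager=eval_sets_manager,
  )
  request = InferenceRequest(
      app_name='my_app',
      eval_set_id='my_eval_set',
      inference_config=InferenceConfig(parallelism=2),
  )
  # Results stream back as each eval case finishes, not in submission order.
  async for result in service.perform_inference(request):
    print(result.eval_case_id, result.status)


asyncio.run(main())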
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing_extensions import override

from ..agents import BaseAgent


class IdentityAgentCreator:
  """An implementation of the AgentCreator interface that always returns a copy of the root agent."""

  def __init__(self, root_agent: BaseAgent):
    self._root_agent = root_agent

  @override
  def get_agent(
      self,
  ) -> BaseAgent:
    """Returns a deep copy of the root agent."""
    # TODO: Use Agent.clone() when the PR is merged.
    # return self._root_agent.model_copy(deep=True)
    return self._root_agent.clone()
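A small illustration of the contract this class implements; `my_agent` stands in for any BaseAgent, and the import path is whatever module the new file lands in:

creator = IdentityAgentCreator(root_agent=my_agent)  # my_agent: any BaseAgent
first = creator.get_agent()
second = creator.get_agent()
assert first is not second  # each call returns an independent clone

Because every caller gets its own copy, concurrent eval runs cannot share mutable agent state.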
@@ -16,6 +16,7 @@ from __future__ import annotations

from abc import ABC
from abc import abstractmethod
+from enum import Enum
from typing import AsyncGenerator
from typing import Optional

@@ -56,6 +57,19 @@ class InferenceConfig(BaseModel):
      charges.""",
  )
+
+  parallelism: int = Field(
+      default=4,
+      description="""Number of parallel inferences to run during an Eval. A few
+factors to consider while changing this value:
+
+1) Your available quota with the model. Models tend to enforce per-minute or
+per-second SLAs. Using a larger value could result in the eval quickly consuming
+the quota.
+
+2) The tools used by the Agent could also have their own SLAs. Using a larger
+value could also overwhelm those tools.""",
+  )


class InferenceRequest(BaseModel):
  """Represents a request to perform inferences for the eval cases in an eval set."""
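The quota concerns above are why the service bounds concurrency with an asyncio.Semaphore, as the LocalEvalService code further down shows. A self-contained sketch of that pattern, with `fake_inference` standing in for a real model call:

import asyncio


async def fake_inference(case_id: str, sem: asyncio.Semaphore) -> str:
  async with sem:  # at most `parallelism` calls are in flight at once
    await asyncio.sleep(0.1)  # stand-in for a real model call
    return case_id


async def main():
  sem = asyncio.Semaphore(2)  # mirrors InferenceConfig(parallelism=2)
  coros = [fake_inference(f'case{i}', sem) for i in range(5)]
  for finished in asyncio.as_completed(coros):
    print(await finished)  # completion order, not submission order


asyncio.run(main())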
@@ -88,6 +102,14 @@ in an eval set are evaluated.
  )


+class InferenceStatus(Enum):
+  """Status of the inference."""
+
+  UNKNOWN = 0
+  SUCCESS = 1
+  FAILURE = 2


class InferenceResult(BaseModel):
  """Contains inference results for a single eval case."""

@@ -106,14 +128,25 @@ class InferenceResult(BaseModel):
      description="""Id of the eval case for which inferences were generated.""",
  )

-  inferences: list[Invocation] = Field(
-      description="""Inferences obtained from the Agent for the eval case."""
+  inferences: Optional[list[Invocation]] = Field(
+      default=None,
+      description="""Inferences obtained from the Agent for the eval case.""",
  )

  session_id: Optional[str] = Field(
      description="""Id of the inference session."""
  )

+  status: InferenceStatus = Field(
+      default=InferenceStatus.UNKNOWN,
+      description="""Status of the inference.""",
+  )
+
+  error_message: Optional[str] = Field(
+      default=None,
+      description="""Error message if the inference failed.""",
+  )


class EvaluateRequest(BaseModel):
  model_config = ConfigDict(
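With `inferences` now optional, a failed case can be represented without any inferences at all. The field values below are made up, but the keyword arguments match how LocalEvalService constructs the model later in this commit:

failed = InferenceResult(
    app_name='my_app',
    eval_set_id='my_eval_set',
    eval_case_id='case1',
    session_id='___eval___session___example',
    status=InferenceStatus.FAILURE,
    error_message='model quota exceeded',
)
assert failed.inferences is None  # stays at its default on failure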
@@ -137,7 +137,7 @@ class EvaluationGenerator:
  async def _generate_inferences_from_root_agent(
      invocations: list[Invocation],
      root_agent: Agent,
-      reset_func: Any,
+      reset_func: Optional[Any] = None,
      initial_session: Optional[SessionInput] = None,
      session_id: Optional[str] = None,
      session_service: Optional[BaseSessionService] = None,
@@ -0,0 +1,183 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import logging
from typing import AsyncGenerator
from typing import Callable
from typing import Optional
import uuid

from typing_extensions import override

from ..agents import BaseAgent
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..errors.not_found_error import NotFoundError
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.feature_decorator import working_in_progress
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateRequest
from .base_eval_service import InferenceRequest
from .base_eval_service import InferenceResult
from .base_eval_service import InferenceStatus
from .eval_result import EvalCaseResult
from .eval_set import EvalCase
from .eval_set_results_manager import EvalSetResultsManager
from .eval_sets_manager import EvalSetsManager
from .evaluation_generator import EvaluationGenerator
from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
from .metric_evaluator_registry import MetricEvaluatorRegistry

logger = logging.getLogger('google_adk.' + __name__)

EVAL_SESSION_ID_PREFIX = '___eval___session___'


def _get_session_id() -> str:
  return f'{EVAL_SESSION_ID_PREFIX}{uuid.uuid4()}'


@working_in_progress("Incomplete feature, don't use yet")
class LocalEvalService(BaseEvalService):
  """An implementation of BaseEvalService that runs evals locally."""

  def __init__(
      self,
      root_agent: BaseAgent,
      eval_sets_manager: EvalSetsManager,
      metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
      session_service: BaseSessionService = InMemorySessionService(),
      artifact_service: BaseArtifactService = InMemoryArtifactService(),
      eval_set_results_manager: Optional[EvalSetResultsManager] = None,
      session_id_supplier: Callable[[], str] = _get_session_id,
  ):
    self._root_agent = root_agent
    self._eval_sets_manager = eval_sets_manager
    self._metric_evaluator_registry = metric_evaluator_registry
    self._session_service = session_service
    self._artifact_service = artifact_service
    self._eval_set_results_manager = eval_set_results_manager
    self._session_id_supplier = session_id_supplier

  @override
  async def perform_inference(
      self,
      inference_request: InferenceRequest,
  ) -> AsyncGenerator[InferenceResult, None]:
    """Yields InferenceResults obtained from the Agent as and when they are available.

    Args:
      inference_request: The request for generating inferences.
    """
    # Get the eval set from storage.
    eval_set = self._eval_sets_manager.get_eval_set(
        app_name=inference_request.app_name,
        eval_set_id=inference_request.eval_set_id,
    )

    if not eval_set:
      raise NotFoundError(
          f'Eval set with id {inference_request.eval_set_id} not found for app'
          f' {inference_request.app_name}'
      )

    # Select the eval cases for which we need to run inferencing. If the
    # inference request specified eval cases, then we use only those.
    eval_cases = eval_set.eval_cases
    if inference_request.eval_case_ids:
      eval_cases = [
          eval_case
          for eval_case in eval_cases
          if eval_case.eval_id in inference_request.eval_case_ids
      ]

    root_agent = self._root_agent.clone()

    semaphore = asyncio.Semaphore(
        value=inference_request.inference_config.parallelism
    )

    async def run_inference(eval_case):
      async with semaphore:
        return await self._perform_inference_single_eval_item(
            app_name=inference_request.app_name,
            eval_set_id=inference_request.eval_set_id,
            eval_case=eval_case,
            root_agent=root_agent,
        )

    inference_results = [run_inference(eval_case) for eval_case in eval_cases]
    for inference_result in asyncio.as_completed(inference_results):
      yield await inference_result

  @override
  async def evaluate(
      self,
      evaluate_request: EvaluateRequest,
  ) -> AsyncGenerator[EvalCaseResult, None]:
    """Yields an EvalCaseResult for each item as and when they are available.

    Args:
      evaluate_request: The request to perform metric evaluations on the
        inferences.
    """
    raise NotImplementedError()

  async def _perform_inference_single_eval_item(
      self,
      app_name: str,
      eval_set_id: str,
      eval_case: EvalCase,
      root_agent: BaseAgent,
  ) -> InferenceResult:
    initial_session = eval_case.session_input
    session_id = self._session_id_supplier()
    inference_result = InferenceResult(
        app_name=app_name,
        eval_set_id=eval_set_id,
        eval_case_id=eval_case.eval_id,
        session_id=session_id,
    )

    try:
      inferences = (
          await EvaluationGenerator._generate_inferences_from_root_agent(
              invocations=eval_case.conversation,
              root_agent=root_agent,
              initial_session=initial_session,
              session_id=session_id,
              session_service=self._session_service,
              artifact_service=self._artifact_service,
          )
      )

      inference_result.inferences = inferences
      inference_result.status = InferenceStatus.SUCCESS

      return inference_result
    except Exception as e:
      # We intentionally catch the broad Exception so that a failure in one
      # inference does not affect the other inferences.
      logger.error(
          'Inference failed for eval case `%s` with error %s',
          eval_case.eval_id,
          e,
      )
      inference_result.status = InferenceStatus.FAILURE
      inference_result.error_message = str(e)
      return inference_result
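Every collaborator in the constructor above has a default, so callers can inject their own implementations. A hedged sketch of those injection points; `my_agent` and `my_eval_sets_manager` are placeholders, and the fixed id supplier is only an illustration (for example, to get deterministic session ids in tests):

from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.sessions.in_memory_session_service import InMemorySessionService

service = LocalEvalService(
    root_agent=my_agent,                         # any BaseAgent
    eval_sets_manager=my_eval_sets_manager,      # the only other required collaborator
    session_service=InMemorySessionService(),    # default, shown explicitly
    artifact_service=InMemoryArtifactService(),  # default, shown explicitly
    session_id_supplier=lambda: '___eval___session___fixed',  # deterministic ids
)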
@@ -0,0 +1,144 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from google.adk.agents.llm_agent import LlmAgent
from google.adk.errors.not_found_error import NotFoundError
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.eval_set import EvalCase
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.eval_sets_manager import EvalSetsManager
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.models.registry import LLMRegistry
import pytest


@pytest.fixture
def mock_eval_sets_manager():
  return mock.create_autospec(EvalSetsManager)


@pytest.fixture
def dummy_agent():
  llm = LLMRegistry.new_llm("gemini-pro")
  return LlmAgent(name="test_agent", model=llm)


@pytest.fixture
def eval_service(dummy_agent, mock_eval_sets_manager):
  return LocalEvalService(
      root_agent=dummy_agent,
      eval_sets_manager=mock_eval_sets_manager,
  )


@pytest.mark.asyncio
async def test_perform_inference_success(
    eval_service, dummy_agent, mock_eval_sets_manager
):
  eval_set = EvalSet(
      eval_set_id="test_eval_set",
      eval_cases=[
          EvalCase(eval_id="case1", conversation=[], session_input=None),
          EvalCase(eval_id="case2", conversation=[], session_input=None),
      ],
  )
  mock_eval_sets_manager.get_eval_set.return_value = eval_set

  mock_inference_result = mock.MagicMock()
  eval_service._perform_inference_single_eval_item = mock.AsyncMock(
      return_value=mock_inference_result
  )

  inference_request = InferenceRequest(
      app_name="test_app",
      eval_set_id="test_eval_set",
      inference_config=InferenceConfig(parallelism=2),
  )

  results = []
  async for result in eval_service.perform_inference(inference_request):
    results.append(result)

  assert len(results) == 2
  assert results[0] == mock_inference_result
  assert results[1] == mock_inference_result
  mock_eval_sets_manager.get_eval_set.assert_called_once_with(
      app_name="test_app", eval_set_id="test_eval_set"
  )
  assert eval_service._perform_inference_single_eval_item.call_count == 2


@pytest.mark.asyncio
async def test_perform_inference_with_case_ids(
    eval_service, dummy_agent, mock_eval_sets_manager
):
  eval_set = EvalSet(
      eval_set_id="test_eval_set",
      eval_cases=[
          EvalCase(eval_id="case1", conversation=[], session_input=None),
          EvalCase(eval_id="case2", conversation=[], session_input=None),
          EvalCase(eval_id="case3", conversation=[], session_input=None),
      ],
  )
  mock_eval_sets_manager.get_eval_set.return_value = eval_set

  mock_inference_result = mock.MagicMock()
  eval_service._perform_inference_single_eval_item = mock.AsyncMock(
      return_value=mock_inference_result
  )

  inference_request = InferenceRequest(
      app_name="test_app",
      eval_set_id="test_eval_set",
      eval_case_ids=["case1", "case3"],
      inference_config=InferenceConfig(parallelism=1),
  )

  results = []
  async for result in eval_service.perform_inference(inference_request):
    results.append(result)

  assert len(results) == 2
  eval_service._perform_inference_single_eval_item.assert_any_call(
      app_name="test_app",
      eval_set_id="test_eval_set",
      eval_case=eval_set.eval_cases[0],
      root_agent=dummy_agent,
  )
  eval_service._perform_inference_single_eval_item.assert_any_call(
      app_name="test_app",
      eval_set_id="test_eval_set",
      eval_case=eval_set.eval_cases[2],
      root_agent=dummy_agent,
  )


@pytest.mark.asyncio
async def test_perform_inference_eval_set_not_found(
    eval_service, mock_eval_sets_manager
):
  mock_eval_sets_manager.get_eval_set.return_value = None

  inference_request = InferenceRequest(
      app_name="test_app",
      eval_set_id="not_found_set",
      inference_config=InferenceConfig(parallelism=1),
  )

  with pytest.raises(NotFoundError):
    async for _ in eval_service.perform_inference(inference_request):
      pass