fix: Update create_eval_set API to return the created EvalSet and update its route

PiperOrigin-RevId: 797974571
This commit is contained in:
Ankur Sharma
2025-08-21 17:13:15 -07:00
committed by Copybara-Service
parent 157f73181d
commit f660180854
8 changed files with 101 additions and 29 deletions
+34 -6
View File
@@ -45,6 +45,7 @@ from opentelemetry.sdk.trace import TracerProvider
from pydantic import Field
from pydantic import ValidationError
from starlette.types import Lifespan
from typing_extensions import deprecated
from typing_extensions import override
from watchdog.observers import Observer
@@ -66,6 +67,7 @@ from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_metrics import MetricInfo
from ..evaluation.eval_result import EvalSetResult
from ..evaluation.eval_set import EvalSet
from ..evaluation.eval_set_results_manager import EvalSetResultsManager
from ..evaluation.eval_sets_manager import EvalSetsManager
from ..events.event import Event
@@ -197,6 +199,10 @@ class GetEventGraphResult(common.BaseModel):
dot_src: str
class CreateEvalSetRequest(common.BaseModel):
  """Request body accepted by the `create_eval_set` endpoint."""

  # Eval set to create. The `create_eval_set` route forwards only its
  # `eval_set_id` to the eval sets manager.
  eval_set: EvalSet
class AdkWebServer:
"""Helper class for setting up and running the ADK web server on FastAPI.
@@ -466,23 +472,45 @@ class AdkWebServer:
)
@app.post(
"/apps/{app_name}/eval_sets/{eval_set_id}",
"/apps/{app_name}/eval-sets",
response_model_exclude_none=True,
tags=[TAG_EVALUATION],
)
async def create_eval_set(
app_name: str,
eval_set_id: str,
):
"""Creates an eval set, given the id."""
app_name: str, create_eval_set_request: CreateEvalSetRequest
) -> EvalSet:
try:
self.eval_sets_manager.create_eval_set(app_name, eval_set_id)
return self.eval_sets_manager.create_eval_set(
app_name=app_name,
eval_set_id=create_eval_set_request.eval_set.eval_set_id,
)
except ValueError as ve:
raise HTTPException(
status_code=400,
detail=str(ve),
) from ve
@deprecated(
    "Please use create_eval_set instead. This will be removed in future"
    " releases."
)
@app.post(
    "/apps/{app_name}/eval_sets/{eval_set_id}",
    response_model_exclude_none=True,
    tags=[TAG_EVALUATION],
    # Decorators apply bottom-up, so the route is registered with the
    # undecorated handler before `@deprecated` wraps it; requests served
    # by the router therefore never go through the wrapper that emits the
    # warning. Flagging the route here surfaces the deprecation to API
    # clients through the generated OpenAPI schema.
    deprecated=True,
)
async def create_eval_set_legacy(
    app_name: str,
    eval_set_id: str,
):
  """Creates an eval set, given the id.

  Deprecated route kept for existing clients; it delegates to
  `create_eval_set` with an empty eval set carrying `eval_set_id`.

  Args:
    app_name: Name of the app the eval set belongs to.
    eval_set_id: Id of the eval set to create.

  Raises:
    HTTPException: Propagated from `create_eval_set`, which answers with
      status 400 when the eval sets manager raises ValueError.
  """
  # The created EvalSet is not returned, so this route still answers
  # without a body, as it did when it was the only create route.
  await create_eval_set(
      app_name=app_name,
      create_eval_set_request=CreateEvalSetRequest(
          eval_set=EvalSet(eval_set_id=eval_set_id, eval_cases=[])
      ),
  )
@app.get(
"/apps/{app_name}/eval_sets",
response_model_exclude_none=True,
+11 -3
View File
@@ -18,7 +18,6 @@ from abc import ABC
from abc import abstractmethod
from typing import Optional
from ..errors.not_found_error import NotFoundError
from .eval_case import EvalCase
from .eval_set import EvalSet
@@ -31,8 +30,17 @@ class EvalSetsManager(ABC):
"""Returns an EvalSet identified by an app_name and eval_set_id."""
@abstractmethod
def create_eval_set(self, app_name: str, eval_set_id: str):
"""Creates an empty EvalSet given the app_name and eval_set_id."""
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
  """Creates and returns an empty EvalSet given the app_name and eval_set_id.

  Returns:
    The newly created, empty EvalSet.

  Raises:
    ValueError: If the eval set id is not valid or an eval set already
      exists. A valid eval set id is a string made up of one or more of the
      following characters:
      - Lower case characters
      - Upper case characters
      - 0-9
      - Underscore
  """
@abstractmethod
def list_eval_sets(self, app_name: str) -> list[str]:
@@ -99,8 +99,12 @@ class GcsEvalSetsManager(EvalSetsManager):
return self._load_eval_set_from_blob(eval_set_blob_name)
@override
def create_eval_set(self, app_name: str, eval_set_id: str):
"""Creates an empty EvalSet and saves it to GCS."""
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
"""Creates an empty EvalSet and saves it to GCS.
Raises:
ValueError: If eval set id is not valid or an eval set already exists.
"""
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
new_eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
if self.bucket.blob(new_eval_set_blob_name).exists():
@@ -115,6 +119,7 @@ class GcsEvalSetsManager(EvalSetsManager):
creation_timestamp=time.time(),
)
self._write_eval_set_to_blob(new_eval_set_blob_name, new_eval_set)
return new_eval_set
@override
def list_eval_sets(self, app_name: str) -> list[str]:
@@ -65,6 +65,7 @@ class InMemoryEvalSetsManager(EvalSetsManager):
)
self._eval_sets[app_name][eval_set_id] = new_eval_set
self._eval_cases[app_name][eval_set_id] = {}
return new_eval_set
@override
def list_eval_sets(self, app_name: str) -> list[str]:
@@ -205,8 +205,12 @@ class LocalEvalSetsManager(EvalSetsManager):
return None
@override
def create_eval_set(self, app_name: str, eval_set_id: str):
"""Creates an empty EvalSet given the app_name and eval_set_id."""
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
"""Creates and returns an empty EvalSet given the app_name and eval_set_id.
Raises:
ValueError: If eval set id is not valid or an eval set already exists.
"""
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
# Define the file path
@@ -224,6 +228,11 @@ class LocalEvalSetsManager(EvalSetsManager):
creation_timestamp=time.time(),
)
self._write_eval_set_to_path(new_eval_set_path, new_eval_set)
return new_eval_set
raise ValueError(
f"EvalSet {eval_set_id} already exists for app {app_name}."
)
@override
def list_eval_sets(self, app_name: str) -> list[str]:
@@ -79,17 +79,21 @@ class TestGcsEvalSetsManager:
app_name, eval_set_id
)
gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id)
created_eval_set = gcs_eval_sets_manager.create_eval_set(
app_name, eval_set_id
)
expected_eval_set = EvalSet(
eval_set_id=eval_set_id,
name=eval_set_id,
eval_cases=[],
creation_timestamp=mocked_time,
)
mock_write_eval_set_to_blob.assert_called_once_with(
eval_set_blob_name,
EvalSet(
eval_set_id=eval_set_id,
name=eval_set_id,
eval_cases=[],
creation_timestamp=mocked_time,
),
expected_eval_set,
)
assert created_eval_set == expected_eval_set
def test_gcs_eval_sets_manager_create_eval_set_invalid_id(
self, gcs_eval_sets_manager
@@ -41,8 +41,7 @@ def eval_case_id():
def test_create_eval_set(manager, app_name, eval_set_id):
manager.create_eval_set(app_name, eval_set_id)
eval_set = manager.get_eval_set(app_name, eval_set_id)
eval_set = manager.create_eval_set(app_name, eval_set_id)
assert eval_set is not None
assert eval_set.eval_set_id == eval_set_id
assert eval_set.eval_cases == []
@@ -370,16 +370,21 @@ class TestLocalEvalSetsManager:
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)
created_eval_set = local_eval_sets_manager.create_eval_set(
app_name, eval_set_id
)
expected_eval_set = EvalSet(
eval_set_id=eval_set_id,
name=eval_set_id,
eval_cases=[],
creation_timestamp=mocked_time,
)
mock_write_eval_set_to_path.assert_called_once_with(
eval_set_file_path,
EvalSet(
eval_set_id=eval_set_id,
name=eval_set_id,
eval_cases=[],
creation_timestamp=mocked_time,
),
expected_eval_set,
)
assert created_eval_set == expected_eval_set
def test_local_eval_sets_manager_create_eval_set_invalid_id(
self, local_eval_sets_manager
@@ -390,6 +395,19 @@ class TestLocalEvalSetsManager:
with pytest.raises(ValueError, match="Invalid Eval Set Id"):
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)
def test_local_eval_sets_manager_create_eval_set_already_exists(
    self, local_eval_sets_manager, mocker
):
  """Creating an eval set raises ValueError when `os.path.exists` is True."""
  # Make every existence check report that the path is already present.
  mocker.patch("os.path.exists", return_value=True)
  expected_message = (
      "EvalSet existing_eval_set_id already exists for app test_app."
  )

  with pytest.raises(ValueError, match=expected_message):
    local_eval_sets_manager.create_eval_set(
        "test_app", "existing_eval_set_id"
    )
def test_local_eval_sets_manager_list_eval_sets_success(
self, local_eval_sets_manager, mocker
):