You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Update create_eval_set API to return the created EvalSet and it route
PiperOrigin-RevId: 797974571
This commit is contained in:
committed by
Copybara-Service
parent
157f73181d
commit
f660180854
@@ -45,6 +45,7 @@ from opentelemetry.sdk.trace import TracerProvider
|
||||
from pydantic import Field
|
||||
from pydantic import ValidationError
|
||||
from starlette.types import Lifespan
|
||||
from typing_extensions import deprecated
|
||||
from typing_extensions import override
|
||||
from watchdog.observers import Observer
|
||||
|
||||
@@ -66,6 +67,7 @@ from ..evaluation.eval_metrics import EvalMetricResult
|
||||
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
|
||||
from ..evaluation.eval_metrics import MetricInfo
|
||||
from ..evaluation.eval_result import EvalSetResult
|
||||
from ..evaluation.eval_set import EvalSet
|
||||
from ..evaluation.eval_set_results_manager import EvalSetResultsManager
|
||||
from ..evaluation.eval_sets_manager import EvalSetsManager
|
||||
from ..events.event import Event
|
||||
@@ -197,6 +199,10 @@ class GetEventGraphResult(common.BaseModel):
|
||||
dot_src: str
|
||||
|
||||
|
||||
class CreateEvalSetRequest(common.BaseModel):
|
||||
eval_set: EvalSet
|
||||
|
||||
|
||||
class AdkWebServer:
|
||||
"""Helper class for setting up and running the ADK web server on FastAPI.
|
||||
|
||||
@@ -466,23 +472,45 @@ class AdkWebServer:
|
||||
)
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}",
|
||||
"/apps/{app_name}/eval-sets",
|
||||
response_model_exclude_none=True,
|
||||
tags=[TAG_EVALUATION],
|
||||
)
|
||||
async def create_eval_set(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
):
|
||||
"""Creates an eval set, given the id."""
|
||||
app_name: str, create_eval_set_request: CreateEvalSetRequest
|
||||
) -> EvalSet:
|
||||
try:
|
||||
self.eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||
return self.eval_sets_manager.create_eval_set(
|
||||
app_name=app_name,
|
||||
eval_set_id=create_eval_set_request.eval_set.eval_set_id,
|
||||
)
|
||||
except ValueError as ve:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=str(ve),
|
||||
) from ve
|
||||
|
||||
@deprecated(
|
||||
"Please use create_eval_set instead. This will be removed in future"
|
||||
" releases."
|
||||
)
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval_sets/{eval_set_id}",
|
||||
response_model_exclude_none=True,
|
||||
tags=[TAG_EVALUATION],
|
||||
)
|
||||
async def create_eval_set_legacy(
|
||||
app_name: str,
|
||||
eval_set_id: str,
|
||||
):
|
||||
"""Creates an eval set, given the id."""
|
||||
await create_eval_set(
|
||||
app_name=app_name,
|
||||
create_eval_set_request=CreateEvalSetRequest(
|
||||
eval_set=EvalSet(eval_set_id=eval_set_id, eval_cases=[])
|
||||
),
|
||||
)
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/eval_sets",
|
||||
response_model_exclude_none=True,
|
||||
|
||||
@@ -18,7 +18,6 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from .eval_case import EvalCase
|
||||
from .eval_set import EvalSet
|
||||
|
||||
@@ -31,8 +30,17 @@ class EvalSetsManager(ABC):
|
||||
"""Returns an EvalSet identified by an app_name and eval_set_id."""
|
||||
|
||||
@abstractmethod
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str):
|
||||
"""Creates an empty EvalSet given the app_name and eval_set_id."""
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
||||
"""Creates and returns an empty EvalSet given the app_name and eval_set_id.
|
||||
|
||||
Raises:
|
||||
ValueError: If eval set id is not valid or an eval set already exists. A
|
||||
valid eval set id is string that has one or more of following characters:
|
||||
- Lower case characters
|
||||
- Upper case characters
|
||||
- 0-9
|
||||
- Underscore
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -99,8 +99,12 @@ class GcsEvalSetsManager(EvalSetsManager):
|
||||
return self._load_eval_set_from_blob(eval_set_blob_name)
|
||||
|
||||
@override
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str):
|
||||
"""Creates an empty EvalSet and saves it to GCS."""
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
||||
"""Creates an empty EvalSet and saves it to GCS.
|
||||
|
||||
Raises:
|
||||
ValueError: If eval set id is not valid or an eval set already exists.
|
||||
"""
|
||||
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
|
||||
new_eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
|
||||
if self.bucket.blob(new_eval_set_blob_name).exists():
|
||||
@@ -115,6 +119,7 @@ class GcsEvalSetsManager(EvalSetsManager):
|
||||
creation_timestamp=time.time(),
|
||||
)
|
||||
self._write_eval_set_to_blob(new_eval_set_blob_name, new_eval_set)
|
||||
return new_eval_set
|
||||
|
||||
@override
|
||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -65,6 +65,7 @@ class InMemoryEvalSetsManager(EvalSetsManager):
|
||||
)
|
||||
self._eval_sets[app_name][eval_set_id] = new_eval_set
|
||||
self._eval_cases[app_name][eval_set_id] = {}
|
||||
return new_eval_set
|
||||
|
||||
@override
|
||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -205,8 +205,12 @@ class LocalEvalSetsManager(EvalSetsManager):
|
||||
return None
|
||||
|
||||
@override
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str):
|
||||
"""Creates an empty EvalSet given the app_name and eval_set_id."""
|
||||
def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
|
||||
"""Creates and returns an empty EvalSet given the app_name and eval_set_id.
|
||||
|
||||
Raises:
|
||||
ValueError: If eval set id is not valid or an eval set already exists.
|
||||
"""
|
||||
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
|
||||
|
||||
# Define the file path
|
||||
@@ -224,6 +228,11 @@ class LocalEvalSetsManager(EvalSetsManager):
|
||||
creation_timestamp=time.time(),
|
||||
)
|
||||
self._write_eval_set_to_path(new_eval_set_path, new_eval_set)
|
||||
return new_eval_set
|
||||
|
||||
raise ValueError(
|
||||
f"EvalSet {eval_set_id} already exists for app {app_name}."
|
||||
)
|
||||
|
||||
@override
|
||||
def list_eval_sets(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -79,17 +79,21 @@ class TestGcsEvalSetsManager:
|
||||
app_name, eval_set_id
|
||||
)
|
||||
|
||||
gcs_eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||
created_eval_set = gcs_eval_sets_manager.create_eval_set(
|
||||
app_name, eval_set_id
|
||||
)
|
||||
|
||||
expected_eval_set = EvalSet(
|
||||
eval_set_id=eval_set_id,
|
||||
name=eval_set_id,
|
||||
eval_cases=[],
|
||||
creation_timestamp=mocked_time,
|
||||
)
|
||||
mock_write_eval_set_to_blob.assert_called_once_with(
|
||||
eval_set_blob_name,
|
||||
EvalSet(
|
||||
eval_set_id=eval_set_id,
|
||||
name=eval_set_id,
|
||||
eval_cases=[],
|
||||
creation_timestamp=mocked_time,
|
||||
),
|
||||
expected_eval_set,
|
||||
)
|
||||
assert created_eval_set == expected_eval_set
|
||||
|
||||
def test_gcs_eval_sets_manager_create_eval_set_invalid_id(
|
||||
self, gcs_eval_sets_manager
|
||||
|
||||
@@ -41,8 +41,7 @@ def eval_case_id():
|
||||
|
||||
|
||||
def test_create_eval_set(manager, app_name, eval_set_id):
|
||||
manager.create_eval_set(app_name, eval_set_id)
|
||||
eval_set = manager.get_eval_set(app_name, eval_set_id)
|
||||
eval_set = manager.create_eval_set(app_name, eval_set_id)
|
||||
assert eval_set is not None
|
||||
assert eval_set.eval_set_id == eval_set_id
|
||||
assert eval_set.eval_cases == []
|
||||
|
||||
@@ -370,16 +370,21 @@ class TestLocalEvalSetsManager:
|
||||
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
||||
)
|
||||
|
||||
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||
created_eval_set = local_eval_sets_manager.create_eval_set(
|
||||
app_name, eval_set_id
|
||||
)
|
||||
|
||||
expected_eval_set = EvalSet(
|
||||
eval_set_id=eval_set_id,
|
||||
name=eval_set_id,
|
||||
eval_cases=[],
|
||||
creation_timestamp=mocked_time,
|
||||
)
|
||||
mock_write_eval_set_to_path.assert_called_once_with(
|
||||
eval_set_file_path,
|
||||
EvalSet(
|
||||
eval_set_id=eval_set_id,
|
||||
name=eval_set_id,
|
||||
eval_cases=[],
|
||||
creation_timestamp=mocked_time,
|
||||
),
|
||||
expected_eval_set,
|
||||
)
|
||||
assert created_eval_set == expected_eval_set
|
||||
|
||||
def test_local_eval_sets_manager_create_eval_set_invalid_id(
|
||||
self, local_eval_sets_manager
|
||||
@@ -390,6 +395,19 @@ class TestLocalEvalSetsManager:
|
||||
with pytest.raises(ValueError, match="Invalid Eval Set Id"):
|
||||
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||
|
||||
def test_local_eval_sets_manager_create_eval_set_already_exists(
|
||||
self, local_eval_sets_manager, mocker
|
||||
):
|
||||
app_name = "test_app"
|
||||
eval_set_id = "existing_eval_set_id"
|
||||
mocker.patch("os.path.exists", return_value=True)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="EvalSet existing_eval_set_id already exists for app test_app.",
|
||||
):
|
||||
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
||||
|
||||
def test_local_eval_sets_manager_list_eval_sets_success(
|
||||
self, local_eval_sets_manager, mocker
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user