You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Fix double JSON encoding when saving eval set results
The `model_dump_json()` method already returns a JSON string, so wrapping it in `json.dumps()` caused the result to be JSON-encoded twice. Closes #3993. Co-authored-by: George Weale <gweale@google.com>. PiperOrigin-RevId: 852455534
This commit is contained in:
committed by
Copybara-Service
parent
bfed19cd78
commit
fc4e3d6f60
@@ -14,8 +14,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .eval_result import EvalCaseResult
|
||||
from .eval_result import EvalSetResult
|
||||
|
||||
@@ -42,3 +45,25 @@ def create_eval_set_result(
|
||||
creation_timestamp=timestamp,
|
||||
)
|
||||
return eval_set_result
|
||||
|
||||
|
||||
def parse_eval_set_result_json(
    eval_set_result_json: str | bytes,
) -> EvalSetResult:
    """Builds an EvalSetResult from its serialized JSON form.

    The payload is first validated directly. If that fails, it is decoded with
    ``json.loads`` so that legacy, double-encoded files (an outer JSON string
    whose content is the inner JSON object) can still be read.

    Args:
      eval_set_result_json: The raw JSON text or bytes to parse.

    Returns:
      The validated EvalSetResult.

    Raises:
      ValidationError or ValueError: The failure from the direct validation is
        re-raised when the payload cannot be decoded by ``json.loads``. Errors
        raised while validating the decoded payload propagate unchanged.
    """
    try:
        return EvalSetResult.model_validate_json(eval_set_result_json)
    except (ValidationError, ValueError) as direct_error:
        # Direct validation failed; check whether this is a legacy payload.
        try:
            unwrapped = json.loads(eval_set_result_json)
        except json.JSONDecodeError:
            # Not decodable as JSON at all: surface the original failure.
            raise direct_error
        else:
            if not isinstance(unwrapped, str):
                # Already a decoded JSON value; validate it as Python data.
                return EvalSetResult.model_validate(unwrapped)
            # Double-encoded: the decoded string is itself the JSON document.
            return EvalSetResult.model_validate_json(unwrapped)
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing_extensions import override
|
||||
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ._eval_set_results_manager_utils import create_eval_set_result
|
||||
from ._eval_set_results_manager_utils import parse_eval_set_result_json
|
||||
from .eval_result import EvalCaseResult
|
||||
from .eval_result import EvalSetResult
|
||||
from .eval_set_results_manager import EvalSetResultsManager
|
||||
@@ -101,7 +102,7 @@ class GcsEvalSetResultsManager(EvalSetResultsManager):
|
||||
if not blob.exists():
|
||||
raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.")
|
||||
eval_set_result_data = blob.download_as_text()
|
||||
return EvalSetResult.model_validate_json(eval_set_result_data)
|
||||
return parse_eval_set_result_json(eval_set_result_data)
|
||||
|
||||
@override
|
||||
def list_eval_set_results(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -22,6 +21,7 @@ from typing_extensions import override
|
||||
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ._eval_set_results_manager_utils import create_eval_set_result
|
||||
from ._eval_set_results_manager_utils import parse_eval_set_result_json
|
||||
from .eval_result import EvalCaseResult
|
||||
from .eval_result import EvalSetResult
|
||||
from .eval_set_results_manager import EvalSetResultsManager
|
||||
@@ -54,14 +54,13 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
|
||||
if not os.path.exists(app_eval_history_dir):
|
||||
os.makedirs(app_eval_history_dir)
|
||||
# Convert to json and write to file.
|
||||
eval_set_result_json = eval_set_result.model_dump_json()
|
||||
eval_set_result_file_path = os.path.join(
|
||||
app_eval_history_dir,
|
||||
eval_set_result.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION,
|
||||
)
|
||||
logger.info("Writing eval result to file: %s", eval_set_result_file_path)
|
||||
with open(eval_set_result_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(eval_set_result_json, indent=2))
|
||||
f.write(eval_set_result.model_dump_json(indent=2))
|
||||
|
||||
@override
|
||||
def get_eval_set_result(
|
||||
@@ -79,8 +78,8 @@ class LocalEvalSetResultsManager(EvalSetResultsManager):
|
||||
if not os.path.exists(maybe_eval_result_file_path):
|
||||
raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.")
|
||||
with open(maybe_eval_result_file_path, "r", encoding="utf-8") as file:
|
||||
eval_result_data = json.load(file)
|
||||
return EvalSetResult.model_validate_json(eval_result_data)
|
||||
eval_result_data = file.read()
|
||||
return parse_eval_set_result_json(eval_result_data)
|
||||
|
||||
@override
|
||||
def list_eval_set_results(self, app_name: str) -> list[str]:
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
from google.adk.errors.not_found_error import NotFoundError
|
||||
from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name
|
||||
from google.adk.evaluation._eval_set_results_manager_utils import create_eval_set_result
|
||||
@@ -165,6 +167,33 @@ class TestGcsEvalSetResultsManager:
|
||||
)
|
||||
assert retrieved_eval_set_result == eval_set_result
|
||||
|
||||
def test_get_eval_set_result_double_encoded_legacy(
    self, gcs_eval_set_results_manager, mocker
):
    """Checks that a double-encoded blob is still read back as the original result.

    The blob content is produced by passing the model's JSON text through
    ``json.dumps`` once more, i.e. a JSON string wrapping the JSON document.
    """
    mocker.patch("time.time", return_value=12345678)
    app_name = "test_app"
    eval_set_id = "test_eval_set"
    expected_result = create_eval_set_result(
        app_name, eval_set_id, _get_test_eval_case_results()
    )

    # Upload the double-encoded payload under the blob name the manager uses.
    legacy_blob = gcs_eval_set_results_manager.bucket.blob(
        gcs_eval_set_results_manager._get_eval_set_result_blob_name(
            app_name, expected_result.eval_set_result_id
        )
    )
    legacy_payload = json.dumps(expected_result.model_dump_json())
    legacy_blob.upload_from_string(
        legacy_payload, content_type="application/json"
    )

    loaded_result = gcs_eval_set_results_manager.get_eval_set_result(
        app_name, expected_result.eval_set_result_id
    )
    assert loaded_result == expected_result
|
||||
|
||||
def test_list_eval_set_results(self, gcs_eval_set_results_manager, mocker):
|
||||
mocker.patch("time.time", return_value=123)
|
||||
app_name = "test_app"
|
||||
|
||||
@@ -85,11 +85,12 @@ class TestLocalEvalSetResultsManager:
|
||||
)
|
||||
assert os.path.exists(expected_file_path)
|
||||
with open(expected_file_path, "r") as f:
|
||||
actual_eval_set_result_json = json.load(f)
|
||||
actual_eval_set_result_data = json.load(f)
|
||||
|
||||
# need to convert eval_set_result to json
|
||||
expected_eval_set_result_json = self.eval_set_result.model_dump_json()
|
||||
assert expected_eval_set_result_json == actual_eval_set_result_json
|
||||
# Verify the file contains a proper JSON object (not double-encoded)
|
||||
# Use mode='json' to serialize enums to their values for comparison
|
||||
expected_eval_set_result_data = self.eval_set_result.model_dump(mode="json")
|
||||
assert expected_eval_set_result_data == actual_eval_set_result_data
|
||||
|
||||
def test_get_eval_set_result(self, mocker):
|
||||
mock_time = mocker.patch("time.time")
|
||||
@@ -102,6 +103,24 @@ class TestLocalEvalSetResultsManager:
|
||||
)
|
||||
assert retrieved_result == self.eval_set_result
|
||||
|
||||
def test_get_eval_set_result_double_encoded_legacy(self):
    """Checks that a double-encoded result file is still read back correctly.

    The file content is the model's JSON text passed through ``json.dumps``
    once more, i.e. a JSON string wrapping the JSON document.
    """
    history_dir = os.path.join(
        self.agents_dir, self.app_name, _ADK_EVAL_HISTORY_DIR
    )
    os.makedirs(history_dir, exist_ok=True)

    legacy_file_name = (
        self.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION
    )
    legacy_file_path = os.path.join(history_dir, legacy_file_name)
    legacy_payload = json.dumps(self.eval_set_result.model_dump_json())
    with open(legacy_file_path, "w", encoding="utf-8") as legacy_file:
        legacy_file.write(legacy_payload)

    loaded_result = self.manager.get_eval_set_result(
        self.app_name, self.eval_set_result_name
    )
    assert loaded_result == self.eval_set_result
|
||||
|
||||
def test_get_eval_set_result_not_found(self, mocker):
|
||||
mock_time = mocker.patch("time.time")
|
||||
mock_time.return_value = self.timestamp
|
||||
|
||||
Reference in New Issue
Block a user