diff --git a/src/google/adk/evaluation/_eval_set_results_manager_utils.py b/src/google/adk/evaluation/_eval_set_results_manager_utils.py index 8505e68d..655c9ec3 100644 --- a/src/google/adk/evaluation/_eval_set_results_manager_utils.py +++ b/src/google/adk/evaluation/_eval_set_results_manager_utils.py @@ -14,8 +14,11 @@ from __future__ import annotations +import json import time +from pydantic import ValidationError + from .eval_result import EvalCaseResult from .eval_result import EvalSetResult @@ -42,3 +45,25 @@ def create_eval_set_result( creation_timestamp=timestamp, ) return eval_set_result + + +def parse_eval_set_result_json( + eval_set_result_json: str | bytes, +) -> EvalSetResult: + """Parses an EvalSetResult from JSON. + + This is backward-compatible with legacy eval set result files that were + double-encoded, where the outer JSON is a string containing the inner JSON + object. + """ + try: + return EvalSetResult.model_validate_json(eval_set_result_json) + except (ValidationError, ValueError) as first_error: + try: + decoded = json.loads(eval_set_result_json) + except json.JSONDecodeError: + raise first_error + + if isinstance(decoded, str): + return EvalSetResult.model_validate_json(decoded) + return EvalSetResult.model_validate(decoded) diff --git a/src/google/adk/evaluation/gcs_eval_set_results_manager.py b/src/google/adk/evaluation/gcs_eval_set_results_manager.py index 860d932f..05e17d22 100644 --- a/src/google/adk/evaluation/gcs_eval_set_results_manager.py +++ b/src/google/adk/evaluation/gcs_eval_set_results_manager.py @@ -22,6 +22,7 @@ from typing_extensions import override from ..errors.not_found_error import NotFoundError from ._eval_set_results_manager_utils import create_eval_set_result +from ._eval_set_results_manager_utils import parse_eval_set_result_json from .eval_result import EvalCaseResult from .eval_result import EvalSetResult from .eval_set_results_manager import EvalSetResultsManager @@ -101,7 +102,7 @@ class GcsEvalSetResultsManager(EvalSetResultsManager): if not blob.exists(): raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") eval_set_result_data = blob.download_as_text() - return EvalSetResult.model_validate_json(eval_set_result_data) + return parse_eval_set_result_json(eval_set_result_data) @override def list_eval_set_results(self, app_name: str) -> list[str]: diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index d1e597c9..2eddb772 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -14,7 +14,6 @@ from __future__ import annotations -import json import logging import os @@ -22,6 +21,7 @@ from typing_extensions import override from ..errors.not_found_error import NotFoundError from ._eval_set_results_manager_utils import create_eval_set_result +from ._eval_set_results_manager_utils import parse_eval_set_result_json from .eval_result import EvalCaseResult from .eval_result import EvalSetResult from .eval_set_results_manager import EvalSetResultsManager @@ -54,14 +54,13 @@ class LocalEvalSetResultsManager(EvalSetResultsManager): if not os.path.exists(app_eval_history_dir): os.makedirs(app_eval_history_dir) # Convert to json and write to file. - eval_set_result_json = eval_set_result.model_dump_json() eval_set_result_file_path = os.path.join( app_eval_history_dir, eval_set_result.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, ) logger.info("Writing eval result to file: %s", eval_set_result_file_path) with open(eval_set_result_file_path, "w", encoding="utf-8") as f: - f.write(json.dumps(eval_set_result_json, indent=2)) + f.write(eval_set_result.model_dump_json(indent=2)) @override def get_eval_set_result( @@ -79,8 +78,8 @@ class LocalEvalSetResultsManager(EvalSetResultsManager): if not os.path.exists(maybe_eval_result_file_path): raise NotFoundError(f"Eval set result `{eval_set_result_id}` not found.") with open(maybe_eval_result_file_path, "r", encoding="utf-8") as file: - eval_result_data = json.load(file) - return EvalSetResult.model_validate_json(eval_result_data) + eval_result_data = file.read() + return parse_eval_set_result_json(eval_result_data) @override def list_eval_set_results(self, app_name: str) -> list[str]: diff --git a/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py index 7fd0bb97..ab04ace1 100644 --- a/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_gcs_eval_set_results_manager.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + from google.adk.errors.not_found_error import NotFoundError from google.adk.evaluation._eval_set_results_manager_utils import _sanitize_eval_set_result_name from google.adk.evaluation._eval_set_results_manager_utils import create_eval_set_result @@ -165,6 +167,33 @@ class TestGcsEvalSetResultsManager: ) assert retrieved_eval_set_result == eval_set_result + def test_get_eval_set_result_double_encoded_legacy( + self, gcs_eval_set_results_manager, mocker + ): + mocker.patch("time.time", return_value=12345678) + app_name = "test_app" + eval_set_id = "test_eval_set" + eval_case_results = _get_test_eval_case_results() + eval_set_result = create_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + + blob_name = gcs_eval_set_results_manager._get_eval_set_result_blob_name( + app_name, eval_set_result.eval_set_result_id + ) + blob = gcs_eval_set_results_manager.bucket.blob(blob_name) + double_encoded_json = json.dumps(eval_set_result.model_dump_json()) + blob.upload_from_string( + double_encoded_json, content_type="application/json" + ) + + retrieved_eval_set_result = ( + gcs_eval_set_results_manager.get_eval_set_result( + app_name, eval_set_result.eval_set_result_id + ) + ) + assert retrieved_eval_set_result == eval_set_result + def test_list_eval_set_results(self, gcs_eval_set_results_manager, mocker): mocker.patch("time.time", return_value=123) app_name = "test_app" diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 45500d71..2bec0b64 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -85,11 +85,12 @@ class TestLocalEvalSetResultsManager: ) assert os.path.exists(expected_file_path) with open(expected_file_path, "r") as f: - actual_eval_set_result_json = json.load(f) + actual_eval_set_result_data = json.load(f) - # need to convert eval_set_result to json - expected_eval_set_result_json = self.eval_set_result.model_dump_json() - assert expected_eval_set_result_json == actual_eval_set_result_json + # Verify the file contains a proper JSON object (not double-encoded) + # Use mode='json' to serialize enums to their values for comparison + expected_eval_set_result_data = self.eval_set_result.model_dump(mode="json") + assert expected_eval_set_result_data == actual_eval_set_result_data def test_get_eval_set_result(self, mocker): mock_time = mocker.patch("time.time") @@ -102,6 +103,24 @@ class TestLocalEvalSetResultsManager: ) assert retrieved_result == self.eval_set_result + def test_get_eval_set_result_double_encoded_legacy(self): + eval_history_dir = os.path.join( + self.agents_dir, self.app_name, _ADK_EVAL_HISTORY_DIR + ) + os.makedirs(eval_history_dir, exist_ok=True) + eval_set_result_file_path = os.path.join( + eval_history_dir, + self.eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION, + ) + double_encoded_json = json.dumps(self.eval_set_result.model_dump_json()) + with open(eval_set_result_file_path, "w", encoding="utf-8") as f: + f.write(double_encoded_json) + + retrieved_result = self.manager.get_eval_set_result( + self.app_name, self.eval_set_result_name + ) + assert retrieved_result == self.eval_set_result + def test_get_eval_set_result_not_found(self, mocker): mock_time = mocker.patch("time.time") mock_time.return_value = self.timestamp