diff --git a/src/google/adk/optimization/local_eval_sampler.py b/src/google/adk/optimization/local_eval_sampler.py new file mode 100644 index 00000000..b00c3428 --- /dev/null +++ b/src/google/adk/optimization/local_eval_sampler.py @@ -0,0 +1,367 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Any +from typing import Literal +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + +from ..agents.llm_agent import Agent +from ..evaluation.base_eval_service import EvaluateConfig +from ..evaluation.base_eval_service import EvaluateRequest +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.base_eval_service import InferenceResult +from ..evaluation.eval_case import get_all_tool_calls_with_responses +from ..evaluation.eval_case import IntermediateData +from ..evaluation.eval_case import Invocation +from ..evaluation.eval_case import InvocationEvents +from ..evaluation.eval_config import EvalConfig +from ..evaluation.eval_config import get_eval_metrics_from_config +from ..evaluation.eval_metrics import EvalStatus +from ..evaluation.eval_result import EvalCaseResult +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..evaluation.local_eval_service import LocalEvalService +from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider +from ..utils.context_utils import Aclosing +from .data_types import UnstructuredSamplingResult +from .sampler import Sampler + +logger = logging.getLogger("google_adk." 
+ __name__) + + +def _log_eval_summary(eval_results: list[EvalCaseResult]): + """Logs a summary of eval results.""" + num_pass, num_fail, num_other = 0, 0, 0 + for eval_result in eval_results: + eval_result: EvalCaseResult + if eval_result.final_eval_status == EvalStatus.PASSED: + num_pass += 1 + elif eval_result.final_eval_status == EvalStatus.FAILED: + num_fail += 1 + else: + num_other += 1 + log_str = f"Evaluation summary: {num_pass} PASSED, {num_fail} FAILED" + if num_other: + log_str += f", {num_other} OTHER" + logger.info(log_str) + + +def extract_tool_call_data( + intermediate_data: IntermediateData | InvocationEvents, +) -> list[dict[str, Any]]: + """Extracts tool calls and their responses from intermediate data.""" + call_response_pairs = get_all_tool_calls_with_responses(intermediate_data) + result = [] + for tool_call, tool_response in call_response_pairs: + result.append({ + "name": tool_call.name, + "args": tool_call.args, + "response": tool_response.response if tool_response else None, + }) + return result + + +def extract_single_invocation_info( + invocation: Invocation, +) -> dict[str, Any]: + """Extracts useful information from a single invocation.""" + user_prompt = "" + for part in invocation.user_content.parts: + if part.text and not part.thought: + user_prompt += part.text + agent_response = "" + if invocation.final_response: + for part in invocation.final_response.parts: + if part.text and not part.thought: + agent_response += part.text + result = {"user_prompt": user_prompt, "agent_response": agent_response} + if invocation.intermediate_data: + tool_call_data = extract_tool_call_data(invocation.intermediate_data) + result["tool_calls"] = tool_call_data + return result + + +class LocalEvalSamplerConfig(BaseModel): + """Contains configuration options required by the LocalEvalServiceInterface.""" + + eval_config: EvalConfig = Field( + required=True, + description="The configuration for the evaluation.", + ) + + app_name: str = Field( + required=True, + description="The app name to use for evaluation.", + ) + + train_eval_set: str = Field( + required=True, + description="The name of the eval set to use for optimization.", + ) + + train_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for optimization. If not provided," + " all eval cases in the train_eval_set will be used." + ), + ) + + validation_eval_set: Optional[str] = Field( + default=None, + description=( + "The name of the eval set to use for validating the optimized agent." + " If not provided, the train_eval_set will also be used for" + " validation." + ), + ) + + validation_eval_case_ids: Optional[list[str]] = Field( + default=None, + description=( + "The ids of the eval cases to use for validating the optimized agent." + " If not provided, all eval cases in the validation_eval_set will be" + " used. If validation_eval_set is also not provided, all train eval" + " cases will be used." 
+ ), + ) + + +class LocalEvalSampler(Sampler[UnstructuredSamplingResult]): + """Evaluates candidate agents with the ADK's LocalEvalService.""" + + def __init__( + self, + config: LocalEvalSamplerConfig, + eval_sets_manager: EvalSetsManager, + ): + self._config = config + self._eval_sets_manager = eval_sets_manager + + self._train_eval_set = self._config.train_eval_set + self._train_eval_case_ids = ( + self._config.train_eval_case_ids + or self._get_eval_case_ids(self._train_eval_set) + ) + + self._validation_eval_set = ( + self._config.validation_eval_set or self._train_eval_set + ) + if self._config.validation_eval_case_ids: + self._validation_eval_case_ids = self._config.validation_eval_case_ids + elif self._config.validation_eval_set: + self._validation_eval_case_ids = self._get_eval_case_ids( + self._validation_eval_set + ) + else: + self._validation_eval_case_ids = self._train_eval_case_ids + + def _get_selected_example_set_id( + self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET] + ) -> str: + """Returns the ID of the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_set, + Sampler.VALIDATION_SET: self._validation_eval_set, + }[example_set] + + def _get_all_example_ids( + self, example_set: Literal[Sampler.TRAIN_SET, Sampler.VALIDATION_SET] + ) -> list[str]: + """Returns the IDs of all examples in the selected example set.""" + return { + Sampler.TRAIN_SET: self._train_eval_case_ids, + Sampler.VALIDATION_SET: self._validation_eval_case_ids, + }[example_set] + + def _get_eval_case_ids(self, eval_set_id: str) -> list[str]: + """Returns the ids of eval cases in the given eval set.""" + eval_set = self._eval_sets_manager.get_eval_set( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + ) + if eval_set: + return [eval_case.eval_id for eval_case in eval_set.eval_cases] + else: + raise ValueError( + f"Eval set `{eval_set_id}` does not exist for app" + f" `{self._config.app_name}`." + ) + + async def _evaluate_agent( + self, + agent: Agent, + eval_set_id: str, + eval_case_ids: list[str], + ) -> list[EvalCaseResult]: + """Evaluates the agent on the requested eval cases and returns the results. + + Args: + agent: The agent to evaluate. + eval_set_id: The id of the eval set to use for evaluation. + eval_case_ids: The ids of the eval cases to use for evaluation. + + Returns: + A list of EvalCaseResult, one per eval case. 
+ """ + # create the inference request + inference_request = InferenceRequest( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_ids=eval_case_ids, + inference_config=InferenceConfig(), + ) + + # create the LocalEvalService + user_simulator_provider = UserSimulatorProvider( + self._config.eval_config.user_simulator_config + ) + eval_service = LocalEvalService( + root_agent=agent, + eval_sets_manager=self._eval_sets_manager, + user_simulator_provider=user_simulator_provider, + ) + + # inference/sampling + async with Aclosing( + eval_service.perform_inference(inference_request=inference_request) + ) as agen: + inference_results: list[InferenceResult] = [ + inference_result async for inference_result in agen + ] + + # evaluation + eval_metrics = get_eval_metrics_from_config(self._config.eval_config) + evaluate_request = EvaluateRequest( + inference_results=inference_results, + evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), + ) + async with Aclosing( + eval_service.evaluate(evaluate_request=evaluate_request) + ) as agen: + eval_results: list[EvalCaseResult] = [ + eval_result async for eval_result in agen + ] + + return eval_results + + def _extract_eval_data( + self, + eval_set_id: str, + eval_results: list[EvalCaseResult], + ) -> dict[str, dict[str, Any]]: + """Extracts evaluation data from the eval results.""" + eval_data = {} + for eval_result in eval_results: + eval_result_dict = {} + eval_case = self._eval_sets_manager.get_eval_case( + app_name=self._config.app_name, + eval_set_id=eval_set_id, + eval_case_id=eval_result.eval_id, + ) + if eval_case and eval_case.conversation_scenario: + eval_result_dict["conversation_scenario"] = ( + eval_case.conversation_scenario + ) + + per_invocation_results = [] + for ( + per_invocation_result + ) in eval_result.eval_metric_result_per_invocation: + eval_metric_results = [] + for eval_metric_result in per_invocation_result.eval_metric_results: + eval_metric_results.append({ + "metric_name": eval_metric_result.metric_name, + "score": round(eval_metric_result.score, 2), # accurate enough + "eval_status": eval_metric_result.eval_status.name, + }) + per_invocation_result_dict = { + "actual_invocation": extract_single_invocation_info( + per_invocation_result.actual_invocation + ), + "eval_metric_results": eval_metric_results, + } + if per_invocation_result.expected_invocation: + per_invocation_result_dict["expected_invocation"] = ( + extract_single_invocation_info( + per_invocation_result.expected_invocation + ) + ) + per_invocation_results.append(per_invocation_result_dict) + eval_result_dict["invocations"] = per_invocation_results + eval_data[eval_result.eval_id] = eval_result_dict + + return eval_data + + def get_train_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for training the agent.""" + return self._train_eval_case_ids + + def get_validation_example_ids(self) -> list[str]: + """Returns the UIDs of examples to use for validating the optimized agent.""" + return self._validation_eval_case_ids + + async def sample_and_score( + self, + candidate: Agent, + example_set: Literal[ + Sampler.TRAIN_SET, Sampler.VALIDATION_SET + ] = Sampler.VALIDATION_SET, + batch: Optional[list[str]] = None, + capture_full_eval_data: bool = False, + ) -> UnstructuredSamplingResult: + """Evaluates the candidate agent on the batch of examples using the ADK LocalEvalService. + + Args: + candidate: The candidate agent to be evaluated. + example_set: The set of examples to evaluate the candidate agent on. 
+ Possible values are "train" and "validation". + batch: UIDs of examples to evaluate the candidate agent on. If not + provided, all examples from the chosen set will be used. + capture_full_eval_data: If false, it is enough to only calculate the + scores for each example. If true, this method should also capture all + other data required for optimizing the agent (e.g., outputs, + trajectories, and tool calls). + + Returns: + The evaluation results, containing the scores for each example and (if + requested) other data required for optimization. + """ + eval_set_id = self._get_selected_example_set_id(example_set) + if batch is None: + batch = self._get_all_example_ids(example_set) + + eval_results = await self._evaluate_agent(candidate, eval_set_id, batch) + _log_eval_summary(eval_results) + + scores = { + eval_result.eval_id: ( + 1.0 if eval_result.final_eval_status == EvalStatus.PASSED else 0.0 + ) + for eval_result in eval_results + } + + eval_data = ( + self._extract_eval_data(eval_set_id, eval_results) + if capture_full_eval_data + else None + ) + + return UnstructuredSamplingResult(scores=scores, data=eval_data) diff --git a/src/google/adk/optimization/sampler.py b/src/google/adk/optimization/sampler.py index 0a0ff45d..632e5d3d 100644 --- a/src/google/adk/optimization/sampler.py +++ b/src/google/adk/optimization/sampler.py @@ -32,6 +32,9 @@ class Sampler(ABC, Generic[SamplingResult]): to get evaluation results for the candidate agent on the batch of examples. """ + TRAIN_SET = "train" + VALIDATION_SET = "validation" + @abstractmethod def get_train_example_ids(self) -> list[str]: """Returns the UIDs of examples to use for training the agent.""" @@ -46,7 +49,7 @@ class Sampler(ABC, Generic[SamplingResult]): async def sample_and_score( self, candidate: Agent, - example_set: Literal["train", "validation"] = "validation", + example_set: Literal[TRAIN_SET, VALIDATION_SET] = VALIDATION_SET, batch: Optional[list[str]] = None, capture_full_eval_data: bool = False, ) -> SamplingResult: diff --git a/tests/unittests/optimization/local_eval_sampler_test.py b/tests/unittests/optimization/local_eval_sampler_test.py new file mode 100644 index 00000000..6ebd99cb --- /dev/null +++ b/tests/unittests/optimization/local_eval_sampler_test.py @@ -0,0 +1,383 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.agents.llm_agent import Agent +from google.adk.evaluation.base_eval_service import EvaluateConfig +from google.adk.evaluation.base_eval_service import EvaluateRequest +from google.adk.evaluation.base_eval_service import InferenceConfig +from google.adk.evaluation.base_eval_service import InferenceRequest +from google.adk.evaluation.base_eval_service import InferenceResult +from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.eval_case import InvocationEvent +from google.adk.evaluation.eval_case import InvocationEvents +from google.adk.evaluation.eval_config import EvalConfig +from google.adk.evaluation.eval_config import EvalMetric +from google.adk.evaluation.eval_metrics import EvalMetricResult +from google.adk.evaluation.eval_metrics import EvalMetricResultPerInvocation +from google.adk.evaluation.eval_metrics import EvalStatus +from google.adk.evaluation.eval_result import EvalCaseResult +from google.adk.evaluation.eval_sets_manager import EvalSetsManager +from google.adk.optimization.local_eval_sampler import _log_eval_summary +from google.adk.optimization.local_eval_sampler import extract_single_invocation_info +from google.adk.optimization.local_eval_sampler import extract_tool_call_data +from google.adk.optimization.local_eval_sampler import LocalEvalSampler +from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig +from google.genai import types +import pytest + + +def test_log_eval_summary(mocker): + statuses = ( + [EvalStatus.PASSED] * 3 + + [EvalStatus.FAILED] * 2 + + [EvalStatus.NOT_EVALUATED] + ) + expected_log = "Evaluation summary: 3 PASSED, 2 FAILED, 1 OTHER" + + eval_results = [ + mocker.MagicMock(spec=EvalCaseResult, final_eval_status=status) + for status in statuses + ] + mock_logger = mocker.patch( + "google.adk.optimization.local_eval_sampler.logger" + ) + + _log_eval_summary(eval_results) + + mock_logger.info.assert_called_once_with(expected_log) + + +def test_extract_tool_call_data(): + # omitting IntermediateData tests as it is no longer used + # case 1: empty invocation events + assert not extract_tool_call_data(InvocationEvents()) + # case 2: multi call invocation events + multi_call_invocation_events = InvocationEvents( + invocation_events=[ + InvocationEvent( + author="agent", + content=types.Content( + parts=[ + types.Part( + function_call=types.FunctionCall( + id="call_1", + name="tool_1", + args={"a": 1}, + ) + ), + types.Part( + function_call=types.FunctionCall( + id="call_2", + name="tool_2", + args={"b": 2}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_1", + name="tool_1", + response={"result_1": "done"}, + ) + ), + types.Part( + function_response=types.FunctionResponse( + id="call_2", + name="tool_2", + response={"result_2": "done"}, + ) + ), + ] + ), + ) + ] + ) + expected_entries = [ + { + "name": "tool_1", + "args": {"a": 1}, + "response": {"result_1": "done"}, + }, + { + "name": "tool_2", + "args": {"b": 2}, + "response": {"result_2": "done"}, + }, + ] + result = extract_tool_call_data(multi_call_invocation_events) + # order is not guaranteed + for expected_entry in expected_entries: + assert expected_entry in result + assert len(result) == len(expected_entries) + + +def test_extract_single_invocation_info(): + invocation = Invocation( + user_content=types.Content( + parts=[ + types.Part(text="user thought", thought=True), + types.Part(text="Hello agent!"), + ] + ), + final_response=types.Content( + parts=[ + 
types.Part(text="agent thought", thought=True), + types.Part(text="Hello user!"), + ] + ), + ) + + result = extract_single_invocation_info(invocation) + + assert result == { + "user_prompt": "Hello agent!", + "agent_response": "Hello user!", + } + + +@pytest.mark.parametrize( + "config_kwargs, expected_attrs", + [ + ( + {"train_eval_set": "train_set"}, + { + "_train_eval_set": "train_set", + "_train_eval_case_ids": ["train_set_1", "train_set_2"], + "_validation_eval_set": "train_set", + "_validation_eval_case_ids": ["train_set_1", "train_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "train_eval_case_ids": ["t1"]}, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_case_ids": ["t1"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_set": "val_set"}, + { + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["val_set_1", "val_set_2"], + }, + ), + ( + {"train_eval_set": "train_set", "validation_eval_case_ids": ["v1"]}, + { + "_validation_eval_case_ids": ["v1"], + }, + ), + ( + { + "train_eval_set": "train_set", + "train_eval_case_ids": ["t1"], + "validation_eval_set": "val_set", + "validation_eval_case_ids": ["v1"], + }, + { + "_train_eval_case_ids": ["t1"], + "_validation_eval_set": "val_set", + "_validation_eval_case_ids": ["v1"], + }, + ), + ], +) +def test_local_eval_service_interface_init( + mocker, config_kwargs, expected_attrs +): + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + + def mock_get_eval_case_ids(self, eval_set_id): + return [f"{eval_set_id}_1", f"{eval_set_id}_2"] + + mocker.patch.object( + LocalEvalSampler, + "_get_eval_case_ids", + autospec=True, + side_effect=mock_get_eval_case_ids, + ) + + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), app_name="test_app", **config_kwargs + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + for attr, expected_value in expected_attrs.items(): + assert getattr(interface, attr) == expected_value + + +@pytest.mark.asyncio +async def test_evaluate_agent(mocker): + # Mocking LocalEvalService and its methods + mock_eval_service_cls = mocker.patch( + "google.adk.optimization.local_eval_sampler.LocalEvalService" + ) + mock_eval_service = mock_eval_service_cls.return_value + + # mocking inference + mock_inference_result = mocker.MagicMock(spec=InferenceResult) + + async def mock_perform_inference(*args, **kwargs): + yield mock_inference_result + + mock_eval_service.perform_inference.side_effect = mock_perform_inference + + # mocking evaluate + mock_eval_case_result = mocker.MagicMock(spec=EvalCaseResult) + + async def mock_evaluate(*args, **kwargs): + yield mock_eval_case_result + + mock_eval_service.evaluate.side_effect = mock_evaluate + + # mocking get_eval_metrics_from_config + mock_metrics = [EvalMetric(metric_name="test_metric")] + mocker.patch( + "google.adk.optimization.local_eval_sampler.get_eval_metrics_from_config", + return_value=mock_metrics, + ) + + mocker.patch("google.adk.evaluation.base_eval_service.EvaluateConfig") + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Call _evaluate_agent + results = await interface._evaluate_agent( + mocker.MagicMock(spec=Agent), "train_set", ["t1"] + ) + + # Assertions + mock_eval_service.perform_inference.assert_called_once_with( + inference_request=InferenceRequest( + app_name="test_app", + 
eval_set_id="train_set", + eval_case_ids=["t1"], + inference_config=InferenceConfig(), + ) + ) + mock_eval_service.evaluate.assert_called_once_with( + evaluate_request=EvaluateRequest( + inference_results=[mock_inference_result], + evaluate_config=EvaluateConfig(eval_metrics=mock_metrics), + ) + ) + assert results == [mock_eval_case_result] + + +@pytest.mark.asyncio +async def test_extract_eval_data(mocker): + # Mock components + mock_eval_sets_manager = mocker.MagicMock(spec=EvalSetsManager) + mock_eval_case = mocker.MagicMock() + mock_eval_case.conversation_scenario = "test_scenario" + mock_eval_sets_manager.get_eval_case.return_value = mock_eval_case + + # Mock per invocation result + mock_actual_invocation = mocker.MagicMock(spec=Invocation) + mock_expected_invocation = mocker.MagicMock(spec=Invocation) + mock_metric_result = mocker.MagicMock(spec=EvalMetricResult) + mock_metric_result.metric_name = "test_metric" + mock_metric_result.score = 0.854 # should be rounded to 0.85 + mock_metric_result.eval_status = EvalStatus.PASSED + + mock_per_inv_result = mocker.MagicMock(spec=EvalMetricResultPerInvocation) + mock_per_inv_result.actual_invocation = mock_actual_invocation + mock_per_inv_result.expected_invocation = mock_expected_invocation + mock_per_inv_result.eval_metric_results = [mock_metric_result] + + mock_eval_result = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result.eval_id = "t1" + mock_eval_result.eval_metric_result_per_invocation = [mock_per_inv_result] + + # Mock extract_single_invocation_info + mocker.patch( + "google.adk.optimization.local_eval_sampler.extract_single_invocation_info", + side_effect=[{"info": "actual"}, {"info": "expected"}], + ) + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1"], + ) + interface = LocalEvalSampler(config, mock_eval_sets_manager) + + # Call _extract_eval_data + eval_data = interface._extract_eval_data("train_set", [mock_eval_result]) + + # Assertions + assert "t1" in eval_data + assert eval_data["t1"]["conversation_scenario"] == "test_scenario" + assert len(eval_data["t1"]["invocations"]) == 1 + inv = eval_data["t1"]["invocations"][0] + assert inv["actual_invocation"] == {"info": "actual"} + assert inv["expected_invocation"] == {"info": "expected"} + assert inv["eval_metric_results"] == [ + {"metric_name": "test_metric", "score": 0.85, "eval_status": "PASSED"} + ] + + +@pytest.mark.asyncio +async def test_sample_and_score(mocker): + # Mock results + mock_eval_result_1 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_1.eval_id = "t1" + mock_eval_result_1.final_eval_status = EvalStatus.PASSED + + mock_eval_result_2 = mocker.MagicMock(spec=EvalCaseResult) + mock_eval_result_2.eval_id = "t2" + mock_eval_result_2.final_eval_status = EvalStatus.FAILED + + eval_results = [mock_eval_result_1, mock_eval_result_2] + + # Initialize Interface + config = LocalEvalSamplerConfig( + eval_config=EvalConfig(), + app_name="test_app", + train_eval_set="train_set", + train_eval_case_ids=["t1", "t2"], + ) + interface = LocalEvalSampler(config, mocker.MagicMock(spec=EvalSetsManager)) + + # Patch internal methods + mocker.patch.object(interface, "_evaluate_agent", return_value=eval_results) + mock_log_summary = mocker.patch( + "google.adk.optimization.local_eval_sampler._log_eval_summary" + ) + mock_extract_data = mocker.patch.object( + interface, "_extract_eval_data", return_value={"t1": {}, "t2": {}} + ) + + # Call 
sample_and_score + result = await interface.sample_and_score( + mocker.MagicMock(spec=Agent), + example_set="train", + capture_full_eval_data=True, + ) + + # Assertions + assert result.scores == {"t1": 1.0, "t2": 0.0} + assert result.data == {"t1": {}, "t2": {}} + mock_log_summary.assert_called_once_with(eval_results) + mock_extract_data.assert_called_once_with("train_set", eval_results)
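
For context, here is a minimal end-to-end sketch (not part of the change) of how the new sampler might be driven by an optimizer. It assumes an app named `my_app` with an eval set `train_set` already stored under a local agents directory, that `LocalEvalSetsManager` from `google.adk.evaluation.local_eval_sets_manager` can be constructed from that directory, and that a default `EvalConfig()` is acceptable; swap in whatever `EvalSetsManager`, model, and eval criteria you actually use.

```python
# Illustrative sketch only: wiring a LocalEvalSampler around an existing eval set.
# The app name, eval set name, agents directory, and model are placeholders, and
# the LocalEvalSetsManager constructor argument may differ in your setup.
import asyncio

from google.adk.agents.llm_agent import Agent
from google.adk.evaluation.eval_config import EvalConfig
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager
from google.adk.optimization.local_eval_sampler import LocalEvalSampler
from google.adk.optimization.local_eval_sampler import LocalEvalSamplerConfig
from google.adk.optimization.sampler import Sampler


async def main() -> None:
  # Candidate agent to score; an optimizer would typically mutate its
  # instruction or tools between sampling rounds.
  candidate = Agent(
      name="my_app",
      model="gemini-2.0-flash",
      instruction="Answer the user's question concisely.",
  )

  config = LocalEvalSamplerConfig(
      eval_config=EvalConfig(),  # defaults; a real run would set metrics/criteria
      app_name="my_app",
      train_eval_set="train_set",
      # validation_eval_set omitted, so "train_set" is reused for validation.
  )
  sampler = LocalEvalSampler(
      config,
      eval_sets_manager=LocalEvalSetsManager(agents_dir="./agents"),
  )

  # Score the candidate on the training split and keep trajectories/tool calls
  # so the optimizer can inspect failing cases.
  result = await sampler.sample_and_score(
      candidate,
      example_set=Sampler.TRAIN_SET,
      capture_full_eval_data=True,
  )
  for eval_case_id, score in result.scores.items():
    print(eval_case_id, score)  # 1.0 if the case PASSED, 0.0 otherwise


if __name__ == "__main__":
  asyncio.run(main())
```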
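
When `capture_full_eval_data=True`, the per-case dictionary returned in `UnstructuredSamplingResult.data` follows the structure built by `_extract_eval_data` and `extract_single_invocation_info`. The sketch below shows that shape for one eval case; the strings and the metric name are made up for illustration, and the optional keys appear only when the corresponding data exists.

```python
# Approximate shape of result.data for a single eval case id (illustrative values).
example_data = {
    "eval_case_1": {
        # Present only when the eval case defines a conversation scenario.
        "conversation_scenario": "User asks for a refund on a recent order.",
        "invocations": [
            {
                "actual_invocation": {
                    "user_prompt": "I want a refund for order 123.",
                    "agent_response": "I have started the refund for order 123.",
                    # Present only when intermediate data was recorded.
                    "tool_calls": [
                        {
                            "name": "start_refund",
                            "args": {"order_id": "123"},
                            "response": {"status": "ok"},
                        },
                    ],
                },
                "eval_metric_results": [
                    {
                        "metric_name": "tool_trajectory_avg_score",
                        "score": 1.0,  # rounded to two decimals
                        "eval_status": "PASSED",
                    },
                ],
                # Present only when the eval case has a reference invocation.
                "expected_invocation": {
                    "user_prompt": "I want a refund for order 123.",
                    "agent_response": "Refund issued for order 123.",
                },
            },
        ],
    },
}
```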