diff --git a/contributing/samples/gepa/adk_agent.py b/contributing/samples/gepa/adk_agent.py
index b6e49ecf..ad3570dd 100644
--- a/contributing/samples/gepa/adk_agent.py
+++ b/contributing/samples/gepa/adk_agent.py
@@ -30,6 +30,7 @@ from google.adk.agents import base_agent
 from google.adk.agents import llm_agent
 from google.adk.agents import loop_agent
 from google.adk.events import event as event_lib
+from google.adk.models import google_llm
 from google.adk.tools import base_tool
 from google.genai import types
 
@@ -98,6 +99,15 @@ class _Tool(base_tool.BaseTool):
     return env_response.observation
 
 
+def _default_retry_options() -> types.HttpRetryOptions:
+  return types.HttpRetryOptions(
+      initial_delay=2,
+      attempts=4,
+      max_delay=None,
+      exp_base=2.0,
+  )
+
+
 def _adk_agent(
     instruction: str,
     tools: list[base_tool.BaseTool],
@@ -120,7 +130,10 @@ def _adk_agent(
   # TODO - Allow more flexibility in configuring the agent used in the loop.
   return llm_agent.LlmAgent(
       name=name or 'agent',
-      model=model or 'gemini-2.5-flash',
+      model=google_llm.Gemini(
+          model=model or 'gemini-2.5-flash',
+          retry_options=_default_retry_options(),
+      ),
       instruction=instruction,
       tools=tools,
       generate_content_config=types.GenerateContentConfig(
@@ -130,6 +143,10 @@ def _adk_agent(
               mode=types.FunctionCallingConfigMode.VALIDATED
           )
       ),
+      http_options=types.HttpOptions(
+          timeout=30000,
+          retry_options=_default_retry_options(),
+      ),
   ),
 )
 
diff --git a/contributing/samples/gepa/adk_agent_test.py b/contributing/samples/gepa/adk_agent_test.py
index a27d1a96..2eea7325 100644
--- a/contributing/samples/gepa/adk_agent_test.py
+++ b/contributing/samples/gepa/adk_agent_test.py
@@ -345,4 +345,4 @@ def test_model_name_is_set():
   )
   mock_runner_cls.assert_called_once()
   _, runner_kwargs = mock_runner_cls.call_args
-  assert runner_kwargs["agent"].sub_agents[0].model == "some-test-model"
+  assert runner_kwargs["agent"].sub_agents[0].model.model == "some-test-model"
diff --git a/contributing/samples/gepa/gepa_tau_bench.ipynb b/contributing/samples/gepa/gepa_tau_bench.ipynb
index 20eff076..46814bf4 100644
--- a/contributing/samples/gepa/gepa_tau_bench.ipynb
+++ b/contributing/samples/gepa/gepa_tau_bench.ipynb
@@ -25,6 +25,8 @@
     "%cd ..\n",
     "!pip install gepa --quiet\n",
     "\n",
+    "!pip install retry --quiet\n",
+    "\n",
     "%cd tau-bench/"
    ]
   },
@@ -249,16 +251,18 @@
     "# - A GEPA adapter that bridges GEPA's optimization process with tau-bench.\n",
     "\n",
     "\n",
+    "from concurrent.futures import ThreadPoolExecutor\n",
+    "from datetime import datetime\n",
     "import os\n",
     "import json\n",
-    "import random\n",
-    "import traceback\n",
+    "import logging\n",
     "import multiprocessing\n",
+    "import random\n",
+    "from retry import retry\n",
+    "import traceback\n",
     "from typing import List\n",
-    "from datetime import datetime\n",
-    "from concurrent.futures import ThreadPoolExecutor\n",
     "\n",
-    "from google.adk.examples.gepa import tau_bench_agent as tau_bench_agent_lib\n",
+    "import tau_bench_agent as tau_bench_agent_lib\n",
     "from tau_bench.envs import get_env\n",
     "from tau_bench.run import display_metrics\n",
     "from tau_bench.types import EnvRunResult, RunConfig\n",
@@ -349,42 +352,48 @@
     "    if config.shuffle:\n",
     "      random.shuffle(idxs)\n",
     "\n",
-    "    def _run(idx: int) -> EnvRunResult:\n",
+    "    @retry(tries=3, delay=10, backoff=2)\n",
+    "    def _run_with_retry(idx: int) -> EnvRunResult:\n",
     "      isolated_env = get_env(\n",
-    "        config.env,\n",
-    "        user_strategy=config.user_strategy,\n",
-    "        user_model=config.user_model,\n",
-    "        task_split=config.task_split,\n",
-    "        user_provider=config.user_model_provider,\n",
-    "        task_index=idx,\n",
+    "          config.env,\n",
+    "          user_strategy=config.user_strategy,\n",
+    "          user_model=config.user_model,\n",
+    "          task_split=config.task_split,\n",
+    "          user_provider=config.user_model_provider,\n",
+    "          task_index=idx,\n",
     "      )\n",
     "      if print_results:\n",
     "        print(f'Running task {idx}')\n",
-    "      try:\n",
-    "        res = agent.solve(\n",
+    "      res = agent.solve(\n",
     "          env=isolated_env,\n",
     "          task_index=idx,\n",
-    "        )\n",
-    "        result = EnvRunResult(\n",
+    "      )\n",
+    "      return EnvRunResult(\n",
     "          task_id=idx,\n",
     "          reward=res.reward,\n",
     "          info=res.info,\n",
     "          traj=res.messages,\n",
     "          trial=i,\n",
-    "        )\n",
+    "      )\n",
+    "\n",
+    "    def _run(idx: int) -> EnvRunResult:\n",
+    "      try:\n",
+    "        result = _run_with_retry(idx)\n",
     "      except Exception as e:\n",
+    "        logging.warning('Inference error: %s', str(e))\n",
     "        result = EnvRunResult(\n",
-    "          task_id=idx,\n",
-    "          reward=0.0,\n",
-    "          info={'error': str(e), 'traceback': traceback.format_exc()},\n",
-    "          traj=[],\n",
-    "          trial=i,\n",
+    "            task_id=idx,\n",
+    "            reward=0.0,\n",
+    "            info={'error': str(e), 'traceback': traceback.format_exc()},\n",
+    "            traj=[],\n",
+    "            trial=i,\n",
     "        )\n",
+    "\n",
     "      if print_results:\n",
     "        print(\n",
-    "          '✅' if result.reward == 1 else '❌',\n",
-    "          f'task_id={idx}',\n",
-    "          # result.info,\n",
+    "            '✅' if result.reward == 1 else '❌',\n",
+    "            f'task_id={idx}',\n",
+    "            # result.info,\n",
     "        )\n",
     "        print('-----')\n",
     "      with lock:\n",
@@ -446,6 +455,26 @@
     "  task_info: dict\n",
     "\n",
     "\n",
+    "def refine_tau_bench_trajectory(traj: list[dict[str, Any]]) -> None:\n",
+    "  \"\"\"Removes unnecessary info from the trajectory, in place.\"\"\"\n",
+    "  for content in traj:\n",
+    "    for part in content[\"parts\"]:\n",
+    "      # Drop all fields that are not populated.\n",
+    "      to_drop = []\n",
+    "      for key in part:\n",
+    "        if not part[key]:\n",
+    "          to_drop.append(key)\n",
+    "      for key in to_drop:\n",
+    "        del part[key]\n",
+    "\n",
+    "      # For function calls / responses only keep function names, input arguments\n",
+    "      # and outputs.\n",
+    "      if fc := part.get(\"function_call\"):\n",
+    "        part[\"function_call\"] = dict(name=fc[\"name\"], args=fc[\"args\"])\n",
+    "      if fr := part.get(\"function_response\"):\n",
+    "        part[\"function_response\"] = dict(name=fr[\"name\"], args=fr[\"response\"])\n",
+    "\n",
+    "\n",
     "class TauBenchAdapter(GEPAAdapter[\n",
     "    TauBenchDataInst,\n",
     "    TauBenchTrajectory,\n",
@@ -462,7 +491,7 @@
     "      agent_strategy='tool-calling',\n",
     "      user_strategy='llm',\n",
     "      system_instruction_name='system_instruction',\n",
-    "      tool_definitions_name='tool_definitions',\n",
+    "      tools_description: list[dict[str, Any]] | None = None,\n",
     "      max_concurrency=4,\n",
     "  ):\n",
     "    \"\"\"Initializes the TauBenchAdapter.\n",
@@ -476,8 +505,8 @@
     "      user_strategy: The user simulation strategy (e.g., 'llm').\n",
     "      system_instruction_name: The key in the candidate dictionary that holds\n",
     "        the system instruction.\n",
-    "      tool_definitions_name: The key in the candidate dictionary that holds the\n",
-    "        tool definitions.\n",
+    "      tools_description: Describes each of the available tools. This is used as context\n",
+    "        for the prompt proposer.\n",
     "      max_concurrency: The maximum number of tasks to run in parallel.\n",
     "    \"\"\"\n",
     "    self._agent_model = agent_model\n",
@@ -488,7 +517,7 @@
     "    self._user_strategy = user_strategy\n",
     "    self._max_concurrency = max_concurrency\n",
     "    self._system_instruction_name = system_instruction_name\n",
-    "    self._tool_definitions_name = tool_definitions_name\n",
+    "    self._tools_description = tools_description\n",
     "\n",
     "  def evaluate(\n",
     "      self,\n",
@@ -544,7 +573,7 @@
     "              reward=res.reward,\n",
     "              task_info=res.info))\n",
     "          result_traj = res.traj\n",
-    "          # TODO - Consider refining the trajectory format.\n",
+    "          refine_tau_bench_trajectory(result_traj)\n",
     "          trajectories.append(TauBenchTrajectory(result_traj=result_traj))\n",
     "          scores.append(res.reward)\n",
     "\n",
@@ -574,7 +603,13 @@
     "        data instances for reflection.\n",
     "    \"\"\"\n",
     "    system_instruction = candidate[self._system_instruction_name]\n",
-    "    tool_definitions = candidate[self._tool_definitions_name]\n",
+    "\n",
+    "    tool_definitions = json.dumps(\n",
+    "        self._tools_description,\n",
+    "        indent=2,\n",
+    "        default=str,\n",
+    "    )\n",
+    "\n",
     "    inputs = '\\n\\n'.join([\n",
     "        f'# System Instruction\\n{system_instruction}',\n",
     "        f'# Tool Definitions\\n{tool_definitions}',\n",
@@ -670,7 +705,6 @@
     "]\n",
     "\n",
     "system_instruction_name = 'system_instruction'\n",
-    "tool_definitions_name = 'tool_definitions'\n",
     "\n",
     "SEED_SYSTEM_INSTRUCTION = (\n",
     "    'you are a customer support agent helping customers resolve their '\n",
@@ -679,12 +713,6 @@
     ")\n",
     "\n",
     "seed_candidate = {\n",
     "    system_instruction_name: SEED_SYSTEM_INSTRUCTION,\n",
-    "    # TODO - Consider removing tool definition from optimization space.\n",
-    "    tool_definitions_name: json.dumps(\n",
-    "        tool_definitions_by_domain[tau_bench_env],\n",
-    "        indent=2,\n",
-    "        default=str,\n",
-    "    ),\n",
     "}"
    ]
@@ -700,6 +728,7 @@
     "# With the configuration and adapter in place, this section creates the adapter\n",
     "# instance and calls `gepa.optimize()` to start the Automatic Prompt\n",
     "# Optimization (APO) process.\n",
+    "import litellm\n",
     "\n",
     "tau_bench_adapter = TauBenchAdapter(\n",
     "    agent_model=agent_model,\n",
@@ -709,7 +738,7 @@
     "    agent_strategy='tool-calling',\n",
     "    user_strategy='llm',\n",
     "    system_instruction_name=system_instruction_name,\n",
-    "    tool_definitions_name=tool_definitions_name,\n",
+    "    tools_description=tool_definitions_by_domain[tau_bench_env],\n",
     "    max_concurrency=max_concurrency,\n",
     ")\n",
     "\n",
@@ -720,7 +749,13 @@
     "    task_lm=None,  # this must be None when a custom adapter is used\n",
     "    adapter=tau_bench_adapter,\n",
     "    max_metric_calls=max_metric_calls,\n",
-    "    reflection_lm=f'vertex_ai/{reflection_model}',\n",
+    "    reflection_lm=(\n",
+    "        lambda prompt: litellm.completion_with_retries(\n",
+    "            model=f'vertex_ai/{reflection_model}',\n",
+    "            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
+    "            num_retries=4, initial_delay=1, max_delay=1,\n",
+    "        ).choices[0].message.content\n",
+    "    ),\n",
     "    reflection_minibatch_size=reflection_minibatch_size,\n",
     ")\n",
     "list(enumerate(gepa_results.val_aggregate_scores))"
@@ -735,7 +770,6 @@
    "outputs": [],
    "source": [
     "#@title Evaluate All Candidates\n",
-    "%%time\n",
     "\n",
     "\n",
     "# This is the prompt from https://arxiv.org/pdf/2406.12045\n",
@@ -855,15 +889,35 @@
     "    )\n",
     "    system_instruction_to_eval_results[system_instruction] = tau_bench_results"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "w4Q5hMuERuO6"
+   },
+   "outputs": [],
"source": [ + "print(gepa_results.best_candidate['system_instruction'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pbG7aBXLRuO6" + }, + "outputs": [], + "source": [] } ], "metadata": { "colab": { - "provenance": [], "last_runtime": { "build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook", "kind": "private" - } + }, + "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", diff --git a/contributing/samples/gepa/tau_bench_agent.py b/contributing/samples/gepa/tau_bench_agent.py index fa2be632..beb78b96 100644 --- a/contributing/samples/gepa/tau_bench_agent.py +++ b/contributing/samples/gepa/tau_bench_agent.py @@ -28,6 +28,8 @@ from __future__ import annotations from typing import Any import adk_agent +from google.adk.models import llm_response +from google.adk.plugins import base_plugin from google.genai import types from tau_bench import envs from tau_bench import types as tau_bench_types @@ -64,6 +66,26 @@ def _convert_tool(tool_def: dict[str, Any]) -> types.FunctionDeclaration: return types.FunctionDeclaration(**tool_def['function']) +_LLM_CALL_ERROR = 'llm_call_error' + + +class _TauBenchPlugin(base_plugin.BasePlugin): + """Catches LLM errors and emits event with error code for downstream usage.""" + + async def on_model_error_callback( + self, + *, + callback_context: base_plugin.CallbackContext, + llm_request: base_plugin.LlmRequest, + error: Exception, + ) -> llm_response.LlmResponse: + del callback_context, llm_request # Unused. + return llm_response.LlmResponse( + error_code=_LLM_CALL_ERROR, + error_message=str(error), + ) + + class _ADKAgent(tool_calling_agent.ToolCallingAgent): """ADK agent implementation for Tau Bench.""" @@ -82,6 +104,9 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent): Returns: The result of the solve. + + Raises: + - ValueError: If the LLM inference failed. """ # Thought-signature is excluded from the message serialization for the # following reasons: @@ -102,7 +127,11 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent): tools=[_convert_tool(t) for t in env.tools_info], task_index=task_index, max_num_steps=max_num_steps, + plugins=[_TauBenchPlugin(name='error_plugin')], ): + if event.error_code == _LLM_CALL_ERROR: + raise ValueError(f'Error {event.error_code=}: {event.error_message=}') + if not event.content: continue messages.append(event.content.model_dump(exclude=content_exclusion))