feat: Improve Tau-bench ADK colab stability

PiperOrigin-RevId: 825675599
Author: Google Team Member
Date: 2025-10-29 13:07:37 -07:00
Committed by: Copybara-Service
parent 592c5d870e
commit 04dbc42e50
4 changed files with 145 additions and 45 deletions
+18 -1
@@ -30,6 +30,7 @@ from google.adk.agents import base_agent
from google.adk.agents import llm_agent
from google.adk.agents import loop_agent
from google.adk.events import event as event_lib
from google.adk.models import google_llm
from google.adk.tools import base_tool
from google.genai import types
@@ -98,6 +99,15 @@ class _Tool(base_tool.BaseTool):
return env_response.observation
def _default_retry_options() -> types.HttpRetryOptions:
return types.HttpRetryOptions(
initial_delay=2,
attempts=4,
max_delay=None,
exp_base=2.0,
)
def _adk_agent(
instruction: str,
tools: list[base_tool.BaseTool],
@@ -120,7 +130,10 @@ def _adk_agent(
# TODO - Allow more flexibility in configuring the agent used in the loop.
return llm_agent.LlmAgent(
name=name or 'agent',
model=model or 'gemini-2.5-flash',
model=google_llm.Gemini(
model=model or 'gemini-2.5-flash',
retry_options=_default_retry_options(),
),
instruction=instruction,
tools=tools,
generate_content_config=types.GenerateContentConfig(
@@ -130,6 +143,10 @@ def _adk_agent(
mode=types.FunctionCallingConfigMode.VALIDATED
)
),
http_options=types.HttpOptions(
timeout=30000,
retry_options=_default_retry_options(),
),
),
)
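Note on the retry options above: with initial_delay=2, exp_base=2.0 and attempts=4, the backoff schedule works out to roughly 2s, 4s and 8s between attempts. A minimal sketch of that schedule, assuming the n-th retry waits initial_delay * exp_base**n seconds (capped by max_delay when it is set):

# Sketch only; the delay formula is an assumption about HttpRetryOptions.
initial_delay, exp_base, attempts = 2, 2.0, 4
delays = [initial_delay * exp_base**n for n in range(attempts - 1)]
print(delays)  # [2.0, 4.0, 8.0]: waits before retries 2, 3 and 4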
+1 -1
@@ -345,4 +345,4 @@ def test_model_name_is_set():
)
mock_runner_cls.assert_called_once()
_, runner_kwargs = mock_runner_cls.call_args
assert runner_kwargs["agent"].sub_agents[0].model == "some-test-model"
assert runner_kwargs["agent"].sub_agents[0].model.model == "some-test-model"
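The extra `.model` hop follows from the agent change above: `LlmAgent.model` now holds a `google_llm.Gemini` instance rather than a bare string, and the model name lives on that object's `.model` attribute. A small sketch of the relationship:

# The agent's model field is now a Gemini object, so the name string
# sits one attribute deeper than before.
from google.adk.models import google_llm

gemini = google_llm.Gemini(model='some-test-model')
assert gemini.model == 'some-test-model'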
+97 -43
@@ -25,6 +25,8 @@
"%cd ..\n",
"!pip install gepa --quiet\n",
"\n",
"!pip install retry --quiet\n",
"\n",
"%cd tau-bench/"
]
},
@@ -249,16 +251,17 @@
"# - A GEPA adapter that bridges GEPA's optimization process with tau-bench.\n",
"\n",
"\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from datetime import datetime\n",
"import os\n",
"import json\n",
"import random\n",
"import traceback\n",
"import multiprocessing\n",
"import random\n",
"from retry import retry\n",
"import traceback\n",
"from typing import List\n",
"from datetime import datetime\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"\n",
"from google.adk.examples.gepa import tau_bench_agent as tau_bench_agent_lib\n",
"import tau_bench_agent as tau_bench_agent_lib\n",
"from tau_bench.envs import get_env\n",
"from tau_bench.run import display_metrics\n",
"from tau_bench.types import EnvRunResult, RunConfig\n",
@@ -349,42 +352,48 @@
" if config.shuffle:\n",
" random.shuffle(idxs)\n",
"\n",
" def _run(idx: int) -> EnvRunResult:\n",
" @retry(tries=3, delay=10, backoff=2)\n",
" def _run_with_retry(idx: int) -> EnvRunResult:\n",
" isolated_env = get_env(\n",
" config.env,\n",
" user_strategy=config.user_strategy,\n",
" user_model=config.user_model,\n",
" task_split=config.task_split,\n",
" user_provider=config.user_model_provider,\n",
" task_index=idx,\n",
" config.env,\n",
" user_strategy=config.user_strategy,\n",
" user_model=config.user_model,\n",
" task_split=config.task_split,\n",
" user_provider=config.user_model_provider,\n",
" task_index=idx,\n",
" )\n",
" if print_results:\n",
" print(f'Running task {idx}')\n",
" try:\n",
" res = agent.solve(\n",
" res = agent.solve(\n",
" env=isolated_env,\n",
" task_index=idx,\n",
" )\n",
" result = EnvRunResult(\n",
" )\n",
" return EnvRunResult(\n",
" task_id=idx,\n",
" reward=res.reward,\n",
" info=res.info,\n",
" traj=res.messages,\n",
" trial=i,\n",
" )\n",
" )\n",
"\n",
" def _run(idx: int) -> EnvRunResult:\n",
" try:\n",
" result = _run_with_retry(idx)\n",
" except Exception as e:\n",
" logging.warning('Inference error: %s', str(e))\n",
" result = EnvRunResult(\n",
" task_id=idx,\n",
" reward=0.0,\n",
" info={'error': str(e), 'traceback': traceback.format_exc()},\n",
" traj=[],\n",
" trial=i,\n",
" task_id=idx,\n",
" reward=0.0,\n",
" info={'error': str(e), 'traceback': traceback.format_exc()},\n",
" traj=[],\n",
" trial=i,\n",
" )\n",
"\n",
" if print_results:\n",
" print(\n",
" '✅' if result.reward == 1 else '❌',\n",
" f'task_id={idx}',\n",
" # result.info,\n",
" '✅' if result.reward == 1 else '❌',\n",
" f'task_id={idx}',\n",
" # result.info,\n",
" )\n",
" print('-----')\n",
" with lock:\n",
@@ -446,6 +455,26 @@
" task_info: dict\n",
"\n",
"\n",
"def refine_tau_bench_trajectory(traj: list[dict[str, Any]]) -> None:\n",
" \"\"\"Removes unnecessary info from the trajectory, in place.\"\"\"\n",
" for content in traj:\n",
" for part in content[\"parts\"]:\n",
" # Drop all fields that are not populated.\n",
" to_drop = []\n",
" for key in part:\n",
" if not part[key]:\n",
" to_drop.append(key)\n",
" for key in to_drop:\n",
" del part[key]\n",
"\n",
" # For function calls / responses only keep function names, input arguments\n",
" # and outputs.\n",
" if fc := part.get(\"function_call\"):\n",
" part[\"function_call\"] = dict(name=fc[\"name\"], args=fc[\"args\"])\n",
" if fr := part.get(\"function_response\"):\n",
" part[\"function_response\"] = dict(name=fr[\"name\"], args=fr[\"response\"])\n",
"\n",
"\n",
"class TauBenchAdapter(GEPAAdapter[\n",
" TauBenchDataInst,\n",
" TauBenchTrajectory,\n",
@@ -462,7 +491,7 @@
" agent_strategy='tool-calling',\n",
" user_strategy='llm',\n",
" system_instruction_name='system_instruction',\n",
" tool_definitions_name='tool_definitions',\n",
" tools_description: list[dict[str, Any]] | None = None,\n",
" max_concurrency=4,\n",
" ):\n",
" \"\"\"Initializes the TauBenchAdapter.\n",
@@ -476,8 +505,8 @@
" user_strategy: The user simulation strategy (e.g., 'llm').\n",
" system_instruction_name: The key in the candidate dictionary that holds\n",
" the system instruction.\n",
" tool_definitions_name: The key in the candidate dictionary that holds the\n",
" tool definitions.\n",
" tools_description: Describes each of the availble tools. This is used as context\n",
" for the prompt proposer.\n",
" max_concurrency: The maximum number of tasks to run in parallel.\n",
" \"\"\"\n",
" self._agent_model = agent_model\n",
@@ -488,7 +517,7 @@
" self._user_strategy = user_strategy\n",
" self._max_concurrency = max_concurrency\n",
" self._system_instruction_name = system_instruction_name\n",
" self._tool_definitions_name = tool_definitions_name\n",
" self._tools_description = tools_description\n",
"\n",
" def evaluate(\n",
" self,\n",
@@ -544,7 +573,7 @@
" reward=res.reward,\n",
" task_info=res.info))\n",
" result_traj = res.traj\n",
" # TODO - Consider refining the trajectory format.\n",
" refine_tau_bench_trajectory(result_traj)\n",
" trajectories.append(TauBenchTrajectory(result_traj=result_traj))\n",
" scores.append(res.reward)\n",
"\n",
@@ -574,7 +603,13 @@
" data instances for reflection.\n",
" \"\"\"\n",
" system_instruction = candidate[self._system_instruction_name]\n",
" tool_definitions = candidate[self._tool_definitions_name]\n",
"\n",
" tool_definitions = json.dumps(\n",
" self._tools_description,\n",
" indent=2,\n",
" default=str,\n",
" )\n",
"\n",
" inputs = '\\n\\n'.join([\n",
" f'# System Instruction\\n{system_instruction}',\n",
" f'# Tool Definitions\\n{tool_definitions}',\n",
@@ -670,7 +705,6 @@
"]\n",
"\n",
"system_instruction_name = 'system_instruction'\n",
"tool_definitions_name = 'tool_definitions'\n",
"\n",
"SEED_SYSTEM_INSTRUCTION = (\n",
" 'you are a customer support agent helping customers resolve their '\n",
@@ -679,12 +713,6 @@
"\n",
"seed_candidate = {\n",
" system_instruction_name: SEED_SYSTEM_INSTRUCTION,\n",
" # TODO - Consider removing tool definition from optimization space.\n",
" tool_definitions_name: json.dumps(\n",
" tool_definitions_by_domain[tau_bench_env],\n",
" indent=2,\n",
" default=str,\n",
" ),\n",
"}"
]
},
@@ -700,6 +728,7 @@
"# With the configuration and adapter in place, this section creates the adapter\n",
"# instance and calls `gepa.optimize()` to start the Automatic Prompt\n",
"# Optimization (APO) process.\n",
"import litellm\n",
"\n",
"tau_bench_adapter = TauBenchAdapter(\n",
" agent_model=agent_model,\n",
@@ -709,7 +738,7 @@
" agent_strategy='tool-calling',\n",
" user_strategy='llm',\n",
" system_instruction_name=system_instruction_name,\n",
" tool_definitions_name=tool_definitions_name,\n",
" tools_description=tool_definitions_by_domain[tau_bench_env],\n",
" max_concurrency=max_concurrency,\n",
")\n",
"\n",
@@ -720,7 +749,13 @@
" task_lm=None, # this must be None when a custom adapter is used\n",
" adapter=tau_bench_adapter,\n",
" max_metric_calls=max_metric_calls,\n",
" reflection_lm=f'vertex_ai/{reflection_model}',\n",
" reflection_lm = (\n",
" lambda prompt: litellm.completion_with_retries(\n",
" model=f'vertex_ai/{reflection_model}',\n",
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
" num_retries=4, initial_delay=1, max_delay=1,\n",
" ).choices[0].message.content\n",
" ),\n",
" reflection_minibatch_size=reflection_minibatch_size,\n",
")\n",
"list(enumerate(gepa_results.val_aggregate_scores))"
@@ -735,7 +770,6 @@
"outputs": [],
"source": [
"#@title Evaluate All Candidates\n",
"%%time\n",
"\n",
"\n",
"# This is the prompt from https://arxiv.org/pdf/2406.12045\n",
@@ -855,15 +889,35 @@
" )\n",
" system_instruction_to_eval_results[system_instruction] = tau_bench_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "w4Q5hMuERuO6"
},
"outputs": [],
"source": [
"print(gepa_results.best_candidate['system_instruction'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pbG7aBXLRuO6"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": [],
"last_runtime": {
"build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
"kind": "private"
}
},
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
@@ -28,6 +28,8 @@ from __future__ import annotations
from typing import Any
import adk_agent
from google.adk.models import llm_response
from google.adk.plugins import base_plugin
from google.genai import types
from tau_bench import envs
from tau_bench import types as tau_bench_types
@@ -64,6 +66,26 @@ def _convert_tool(tool_def: dict[str, Any]) -> types.FunctionDeclaration:
return types.FunctionDeclaration(**tool_def['function'])
_LLM_CALL_ERROR = 'llm_call_error'
class _TauBenchPlugin(base_plugin.BasePlugin):
"""Catches LLM errors and emits event with error code for downstream usage."""
async def on_model_error_callback(
self,
*,
callback_context: base_plugin.CallbackContext,
llm_request: base_plugin.LlmRequest,
error: Exception,
) -> llm_response.LlmResponse:
del callback_context, llm_request # Unused.
return llm_response.LlmResponse(
error_code=_LLM_CALL_ERROR,
error_message=str(error),
)
class _ADKAgent(tool_calling_agent.ToolCallingAgent):
"""ADK agent implementation for Tau Bench."""
@@ -82,6 +104,9 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent):
Returns:
The result of the solve.
Raises:
ValueError: If the LLM inference fails.
"""
# Thought-signature is excluded from the message serialization for the
# following reasons:
@@ -102,7 +127,11 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent):
tools=[_convert_tool(t) for t in env.tools_info],
task_index=task_index,
max_num_steps=max_num_steps,
plugins=[_TauBenchPlugin(name='error_plugin')],
):
if event.error_code == _LLM_CALL_ERROR:
raise ValueError(f'Error {event.error_code=}: {event.error_message=}')
if not event.content:
continue
messages.append(event.content.model_dump(exclude=content_exclusion))