You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Improve Tau-bench ADK colab stability
PiperOrigin-RevId: 825675599
This commit is contained in:
committed by
Copybara-Service
parent
592c5d870e
commit
04dbc42e50
@@ -30,6 +30,7 @@ from google.adk.agents import base_agent
|
||||
from google.adk.agents import llm_agent
|
||||
from google.adk.agents import loop_agent
|
||||
from google.adk.events import event as event_lib
|
||||
from google.adk.models import google_llm
|
||||
from google.adk.tools import base_tool
|
||||
from google.genai import types
|
||||
|
||||
@@ -98,6 +99,15 @@ class _Tool(base_tool.BaseTool):
|
||||
return env_response.observation
|
||||
|
||||
|
||||
def _default_retry_options() -> types.HttpRetryOptions:
|
||||
return types.HttpRetryOptions(
|
||||
initial_delay=2,
|
||||
attempts=4,
|
||||
max_delay=None,
|
||||
exp_base=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _adk_agent(
|
||||
instruction: str,
|
||||
tools: list[base_tool.BaseTool],
|
||||
@@ -120,7 +130,10 @@ def _adk_agent(
|
||||
# TODO - Allow more flexibility in configuring the agent used in the loop.
|
||||
return llm_agent.LlmAgent(
|
||||
name=name or 'agent',
|
||||
model=model or 'gemini-2.5-flash',
|
||||
model=google_llm.Gemini(
|
||||
model=model or 'gemini-2.5-flash',
|
||||
retry_options=_default_retry_options(),
|
||||
),
|
||||
instruction=instruction,
|
||||
tools=tools,
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
@@ -130,6 +143,10 @@ def _adk_agent(
|
||||
mode=types.FunctionCallingConfigMode.VALIDATED
|
||||
)
|
||||
),
|
||||
http_options=types.HttpOptions(
|
||||
timeout=30000,
|
||||
retry_options=_default_retry_options(),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -345,4 +345,4 @@ def test_model_name_is_set():
|
||||
)
|
||||
mock_runner_cls.assert_called_once()
|
||||
_, runner_kwargs = mock_runner_cls.call_args
|
||||
assert runner_kwargs["agent"].sub_agents[0].model == "some-test-model"
|
||||
assert runner_kwargs["agent"].sub_agents[0].model.model == "some-test-model"
|
||||
|
||||
@@ -25,6 +25,8 @@
|
||||
"%cd ..\n",
|
||||
"!pip install gepa --quiet\n",
|
||||
"\n",
|
||||
"!pip install retry --quiet\n",
|
||||
"\n",
|
||||
"%cd tau-bench/"
|
||||
]
|
||||
},
|
||||
@@ -249,16 +251,17 @@
|
||||
"# - A GEPA adapter that bridges GEPA's optimization process with tau-bench.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from concurrent.futures import ThreadPoolExecutor\n",
|
||||
"from datetime import datetime\n",
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"import random\n",
|
||||
"import traceback\n",
|
||||
"import multiprocessing\n",
|
||||
"import random\n",
|
||||
"from retry import retry\n",
|
||||
"import traceback\n",
|
||||
"from typing import List\n",
|
||||
"from datetime import datetime\n",
|
||||
"from concurrent.futures import ThreadPoolExecutor\n",
|
||||
"\n",
|
||||
"from google.adk.examples.gepa import tau_bench_agent as tau_bench_agent_lib\n",
|
||||
"import tau_bench_agent as tau_bench_agent_lib\n",
|
||||
"from tau_bench.envs import get_env\n",
|
||||
"from tau_bench.run import display_metrics\n",
|
||||
"from tau_bench.types import EnvRunResult, RunConfig\n",
|
||||
@@ -349,42 +352,48 @@
|
||||
" if config.shuffle:\n",
|
||||
" random.shuffle(idxs)\n",
|
||||
"\n",
|
||||
" def _run(idx: int) -> EnvRunResult:\n",
|
||||
" @retry(tries=3, delay=10, backoff=2)\n",
|
||||
" def _run_with_retry(idx: int) -> EnvRunResult:\n",
|
||||
" isolated_env = get_env(\n",
|
||||
" config.env,\n",
|
||||
" user_strategy=config.user_strategy,\n",
|
||||
" user_model=config.user_model,\n",
|
||||
" task_split=config.task_split,\n",
|
||||
" user_provider=config.user_model_provider,\n",
|
||||
" task_index=idx,\n",
|
||||
" config.env,\n",
|
||||
" user_strategy=config.user_strategy,\n",
|
||||
" user_model=config.user_model,\n",
|
||||
" task_split=config.task_split,\n",
|
||||
" user_provider=config.user_model_provider,\n",
|
||||
" task_index=idx,\n",
|
||||
" )\n",
|
||||
" if print_results:\n",
|
||||
" print(f'Running task {idx}')\n",
|
||||
" try:\n",
|
||||
" res = agent.solve(\n",
|
||||
" res = agent.solve(\n",
|
||||
" env=isolated_env,\n",
|
||||
" task_index=idx,\n",
|
||||
" )\n",
|
||||
" result = EnvRunResult(\n",
|
||||
" )\n",
|
||||
" return EnvRunResult(\n",
|
||||
" task_id=idx,\n",
|
||||
" reward=res.reward,\n",
|
||||
" info=res.info,\n",
|
||||
" traj=res.messages,\n",
|
||||
" trial=i,\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def _run(idx: int) -> EnvRunResult:\n",
|
||||
" try:\n",
|
||||
" result = _run_with_retry(idx)\n",
|
||||
" except Exception as e:\n",
|
||||
" logging.warning('Inference error: %s', str(e))\n",
|
||||
" result = EnvRunResult(\n",
|
||||
" task_id=idx,\n",
|
||||
" reward=0.0,\n",
|
||||
" info={'error': str(e), 'traceback': traceback.format_exc()},\n",
|
||||
" traj=[],\n",
|
||||
" trial=i,\n",
|
||||
" task_id=idx,\n",
|
||||
" reward=0.0,\n",
|
||||
" info={'error': str(e), 'traceback': traceback.format_exc()},\n",
|
||||
" traj=[],\n",
|
||||
" trial=i,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if print_results:\n",
|
||||
" print(\n",
|
||||
" '✅' if result.reward == 1 else '❌',\n",
|
||||
" f'task_id={idx}',\n",
|
||||
" # result.info,\n",
|
||||
" '✅' if result.reward == 1 else '❌',\n",
|
||||
" f'task_id={idx}',\n",
|
||||
" # result.info,\n",
|
||||
" )\n",
|
||||
" print('-----')\n",
|
||||
" with lock:\n",
|
||||
@@ -446,6 +455,26 @@
|
||||
" task_info: dict\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def refine_tau_bench_trajectory(traj: list[dict[str, Any]]) -> None:\n",
|
||||
" \"\"\"Removes unnecessary info from the trajectory, in place.\"\"\"\n",
|
||||
" for content in traj:\n",
|
||||
" for part in content[\"parts\"]:\n",
|
||||
" # Drop all fields that are not populated.\n",
|
||||
" to_drop = []\n",
|
||||
" for key in part:\n",
|
||||
" if not part[key]:\n",
|
||||
" to_drop.append(key)\n",
|
||||
" for key in to_drop:\n",
|
||||
" del part[key]\n",
|
||||
"\n",
|
||||
" # For function calls / responses only keep function names, input arguments\n",
|
||||
" # and outputs.\n",
|
||||
" if fc := part.get(\"function_call\"):\n",
|
||||
" part[\"function_call\"] = dict(name=fc[\"name\"], args=fc[\"args\"])\n",
|
||||
" if fr := part.get(\"function_response\"):\n",
|
||||
" part[\"function_response\"] = dict(name=fr[\"name\"], args=fr[\"response\"])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class TauBenchAdapter(GEPAAdapter[\n",
|
||||
" TauBenchDataInst,\n",
|
||||
" TauBenchTrajectory,\n",
|
||||
@@ -462,7 +491,7 @@
|
||||
" agent_strategy='tool-calling',\n",
|
||||
" user_strategy='llm',\n",
|
||||
" system_instruction_name='system_instruction',\n",
|
||||
" tool_definitions_name='tool_definitions',\n",
|
||||
" tools_description: list[dict[str, Any]] | None = None,\n",
|
||||
" max_concurrency=4,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Initializes the TauBenchAdapter.\n",
|
||||
@@ -476,8 +505,8 @@
|
||||
" user_strategy: The user simulation strategy (e.g., 'llm').\n",
|
||||
" system_instruction_name: The key in the candidate dictionary that holds\n",
|
||||
" the system instruction.\n",
|
||||
" tool_definitions_name: The key in the candidate dictionary that holds the\n",
|
||||
" tool definitions.\n",
|
||||
"      tools_description: Describes each of the available tools. This is used as context\n",
|
||||
" for the prompt proposer.\n",
|
||||
" max_concurrency: The maximum number of tasks to run in parallel.\n",
|
||||
" \"\"\"\n",
|
||||
" self._agent_model = agent_model\n",
|
||||
@@ -488,7 +517,7 @@
|
||||
" self._user_strategy = user_strategy\n",
|
||||
" self._max_concurrency = max_concurrency\n",
|
||||
" self._system_instruction_name = system_instruction_name\n",
|
||||
" self._tool_definitions_name = tool_definitions_name\n",
|
||||
" self._tools_description = tools_description\n",
|
||||
"\n",
|
||||
" def evaluate(\n",
|
||||
" self,\n",
|
||||
@@ -544,7 +573,7 @@
|
||||
" reward=res.reward,\n",
|
||||
" task_info=res.info))\n",
|
||||
" result_traj = res.traj\n",
|
||||
" # TODO - Consider refining the trajectory format.\n",
|
||||
" refine_tau_bench_trajectory(result_traj)\n",
|
||||
" trajectories.append(TauBenchTrajectory(result_traj=result_traj))\n",
|
||||
" scores.append(res.reward)\n",
|
||||
"\n",
|
||||
@@ -574,7 +603,13 @@
|
||||
" data instances for reflection.\n",
|
||||
" \"\"\"\n",
|
||||
" system_instruction = candidate[self._system_instruction_name]\n",
|
||||
" tool_definitions = candidate[self._tool_definitions_name]\n",
|
||||
"\n",
|
||||
" tool_definitions = json.dumps(\n",
|
||||
" self._tools_description,\n",
|
||||
" indent=2,\n",
|
||||
" default=str,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" inputs = '\\n\\n'.join([\n",
|
||||
" f'# System Instruction\\n{system_instruction}',\n",
|
||||
" f'# Tool Definitions\\n{tool_definitions}',\n",
|
||||
@@ -670,7 +705,6 @@
|
||||
"]\n",
|
||||
"\n",
|
||||
"system_instruction_name = 'system_instruction'\n",
|
||||
"tool_definitions_name = 'tool_definitions'\n",
|
||||
"\n",
|
||||
"SEED_SYSTEM_INSTRUCTION = (\n",
|
||||
" 'you are a customer support agent helping customers resolve their '\n",
|
||||
@@ -679,12 +713,6 @@
|
||||
"\n",
|
||||
"seed_candidate = {\n",
|
||||
" system_instruction_name: SEED_SYSTEM_INSTRUCTION,\n",
|
||||
" # TODO - Consider removing tool definition from optimization space.\n",
|
||||
" tool_definitions_name: json.dumps(\n",
|
||||
" tool_definitions_by_domain[tau_bench_env],\n",
|
||||
" indent=2,\n",
|
||||
" default=str,\n",
|
||||
" ),\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
@@ -700,6 +728,7 @@
|
||||
"# With the configuration and adapter in place, this section creates the adapter\n",
|
||||
"# instance and calls `gepa.optimize()` to start the Automatic Prompt\n",
|
||||
"# Optimization (APO) process.\n",
|
||||
"import litellm\n",
|
||||
"\n",
|
||||
"tau_bench_adapter = TauBenchAdapter(\n",
|
||||
" agent_model=agent_model,\n",
|
||||
@@ -709,7 +738,7 @@
|
||||
" agent_strategy='tool-calling',\n",
|
||||
" user_strategy='llm',\n",
|
||||
" system_instruction_name=system_instruction_name,\n",
|
||||
" tool_definitions_name=tool_definitions_name,\n",
|
||||
" tools_description=tool_definitions_by_domain[tau_bench_env],\n",
|
||||
" max_concurrency=max_concurrency,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
@@ -720,7 +749,13 @@
|
||||
" task_lm=None, # this must be None when a custom adapter is used\n",
|
||||
" adapter=tau_bench_adapter,\n",
|
||||
" max_metric_calls=max_metric_calls,\n",
|
||||
" reflection_lm=f'vertex_ai/{reflection_model}',\n",
|
||||
" reflection_lm = (\n",
|
||||
" lambda prompt: litellm.completion_with_retries(\n",
|
||||
" model=f'vertex_ai/{reflection_model}',\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" num_retries=4, initial_delay=1, max_delay=1,\n",
|
||||
" ).choices[0].message.content\n",
|
||||
" ),\n",
|
||||
" reflection_minibatch_size=reflection_minibatch_size,\n",
|
||||
")\n",
|
||||
"list(enumerate(gepa_results.val_aggregate_scores))"
|
||||
@@ -735,7 +770,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title Evaluate All Candidates\n",
|
||||
"%%time\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# This is the prompt from https://arxiv.org/pdf/2406.12045\n",
|
||||
@@ -855,15 +889,35 @@
|
||||
" )\n",
|
||||
" system_instruction_to_eval_results[system_instruction] = tau_bench_results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "w4Q5hMuERuO6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(gepa_results.best_candidate['system_instruction'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pbG7aBXLRuO6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"last_runtime": {
|
||||
"build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
|
||||
"kind": "private"
|
||||
}
|
||||
},
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
|
||||
@@ -28,6 +28,8 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
import adk_agent
|
||||
from google.adk.models import llm_response
|
||||
from google.adk.plugins import base_plugin
|
||||
from google.genai import types
|
||||
from tau_bench import envs
|
||||
from tau_bench import types as tau_bench_types
|
||||
@@ -64,6 +66,26 @@ def _convert_tool(tool_def: dict[str, Any]) -> types.FunctionDeclaration:
|
||||
return types.FunctionDeclaration(**tool_def['function'])
|
||||
|
||||
|
||||
_LLM_CALL_ERROR = 'llm_call_error'
|
||||
|
||||
|
||||
class _TauBenchPlugin(base_plugin.BasePlugin):
|
||||
"""Catches LLM errors and emits event with error code for downstream usage."""
|
||||
|
||||
async def on_model_error_callback(
|
||||
self,
|
||||
*,
|
||||
callback_context: base_plugin.CallbackContext,
|
||||
llm_request: base_plugin.LlmRequest,
|
||||
error: Exception,
|
||||
) -> llm_response.LlmResponse:
|
||||
del callback_context, llm_request # Unused.
|
||||
return llm_response.LlmResponse(
|
||||
error_code=_LLM_CALL_ERROR,
|
||||
error_message=str(error),
|
||||
)
|
||||
|
||||
|
||||
class _ADKAgent(tool_calling_agent.ToolCallingAgent):
|
||||
"""ADK agent implementation for Tau Bench."""
|
||||
|
||||
@@ -82,6 +104,9 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent):
|
||||
|
||||
Returns:
|
||||
The result of the solve.
|
||||
|
||||
Raises:
|
||||
- ValueError: If the LLM inference failed.
|
||||
"""
|
||||
# Thought-signature is excluded from the message serialization for the
|
||||
# following reasons:
|
||||
@@ -102,7 +127,11 @@ class _ADKAgent(tool_calling_agent.ToolCallingAgent):
|
||||
tools=[_convert_tool(t) for t in env.tools_info],
|
||||
task_index=task_index,
|
||||
max_num_steps=max_num_steps,
|
||||
plugins=[_TauBenchPlugin(name='error_plugin')],
|
||||
):
|
||||
if event.error_code == _LLM_CALL_ERROR:
|
||||
raise ValueError(f'Error {event.error_code=}: {event.error_message=}')
|
||||
|
||||
if not event.content:
|
||||
continue
|
||||
messages.append(event.content.model_dump(exclude=content_exclusion))
|
||||
|
||||
Reference in New Issue
Block a user