feat: Some small infra fixes to the gepa demo colab

PiperOrigin-RevId: 829240716
Google Team Member authored 2025-11-06 20:43:48 -08:00, committed by Copybara-Service
parent d118479ccf
commit a0cf97eba2
5 changed files with 76 additions and 50 deletions
+2 -37
@@ -26,13 +26,11 @@ import os
 import random
 import traceback
 from typing import Any
 from typing import Callable
 from typing import TypedDict

 import gepa
 from gepa.core.adapter import EvaluationBatch
 from gepa.core.adapter import GEPAAdapter
-from google.genai import types
 from litellm import provider_list
 import rater_lib
-from retry import retry
@@ -46,20 +44,7 @@ from tau_bench.types import EnvRunResult
 from tau_bench.types import RunConfig
 import tau_bench_agent as tau_bench_agent_lib
-from google import genai
-
-
-class FilterInferenceWarnings(logging.Filter):
-  """Filters out Vertex inference warning about non-text parts in response."""
-
-  def filter(self, record: logging.LogRecord) -> bool:
-    """Filters out Vertex inference warning about non-text parts in response."""
-    if record.levelname != 'WARNING':
-      return True
-    message_identifier = record.getMessage()
-    return not message_identifier.startswith(
-        'Warning: there are non-text parts in the response:'
-    )
+import utils


 def run_tau_bench_rollouts(
@@ -494,26 +479,6 @@ def _get_datasets(
 )
-
-
-def reflection_inference_fn(model: str) -> Callable[[str], str]:
-  """Returns an inference function on VertexAI based on provided model."""
-  client = genai.Client()
-
-  @retry(tries=3, delay=10, backoff=2)
-  def _fn(prompt):
-    return client.models.generate_content(
-        model=model,
-        contents=prompt,
-        config=types.GenerateContentConfig(
-            candidate_count=1,
-            thinking_config=types.ThinkingConfig(
-                include_thoughts=True, thinking_budget=-1
-            ),
-        ),
-    ).text
-
-  return _fn


 SEED_SYSTEM_INSTRUCTION = (
     'you are a customer support agent helping customers resolve their '
     'issues by using the right tools'
@@ -618,7 +583,7 @@ def run_gepa(
       task_lm=None,  # this must be None when a custom adapter is used
       adapter=tau_bench_adapter,
       max_metric_calls=config.max_metric_calls,
-      reflection_lm=reflection_inference_fn(config.reflection_model),
+      reflection_lm=utils.reflection_inference_fn(config.reflection_model),
       reflection_minibatch_size=config.reflection_minibatch_size,
       run_dir=output_dir,
   )
+2 -1

@@ -98,6 +98,7 @@
     "\n",
     "import experiment as experiment_lib\n",
     "from google.genai import types\n",
+    "import utils\n",
     "\n",
     "\n",
     "# @markdown ### ☁️ Configure Vertex AI Access\n",
@@ -140,7 +141,7 @@
     "\n",
     "# Set a logging verbosity suited for this experiment. See\n",
     "# https://github.com/google/adk-python/issues/1852 for context\n",
-    "types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
+    "types.logger.addFilter(utils.FilterInferenceWarnings())"
    ]
   },
   {
+3 -1
@@ -26,6 +26,8 @@ from absl import flags
 import experiment
 from google.genai import types
+import utils
+

 _OUTPUT_DIR = flags.DEFINE_string(
     'output_dir',
     None,
@@ -104,7 +106,7 @@ def main(argv: Sequence[str]) -> None:
   for logger in loggers:
     logger.setLevel(logging.WARNING)
-  types.logger.addFilter(experiment.FilterInferenceWarnings())
+  types.logger.addFilter(utils.FilterInferenceWarnings())

   output_dir = os.path.join(
       _OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f')
   )
+56
@@ -0,0 +1,56 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines utility for GEPA experiments."""
+
+import logging
+from typing import Callable
+
+from google.genai import types
+from retry import retry
+
+from google import genai
+
+
+class FilterInferenceWarnings(logging.Filter):
+  """Filters out Vertex inference warning about non-text parts in response."""
+
+  def filter(self, record: logging.LogRecord) -> bool:
+    """Filters out Vertex inference warning about non-text parts in response."""
+    if record.levelname != 'WARNING':
+      return True
+    message_identifier = record.getMessage()
+    return not message_identifier.startswith(
+        'Warning: there are non-text parts in the response:'
+    )
+
+
+def reflection_inference_fn(model: str) -> Callable[[str], str]:
+  """Returns an inference function on VertexAI based on provided model."""
+  client = genai.Client()
+
+  @retry(tries=3, delay=10, backoff=2)
+  def _fn(prompt):
+    return client.models.generate_content(
+        model=model,
+        contents=prompt,
+        config=types.GenerateContentConfig(
+            candidate_count=1,
+            thinking_config=types.ThinkingConfig(
+                include_thoughts=True, thinking_budget=-1
+            ),
+        ),
+    ).text
+
+  return _fn
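
For orientation, a minimal sketch of how this new module is consumed by the script and the colabs below; the model name is an illustrative placeholder, not taken from this commit:

    import utils
    from google.genai import types

    # Drop the noisy 'non-text parts in the response' warning emitted by
    # the Vertex SDK logger (same wiring the colab cells below perform).
    types.logger.addFilter(utils.FilterInferenceWarnings())

    # Build a retried str -> str inference callable for GEPA's reflection step.
    reflect = utils.reflection_inference_fn('gemini-2.5-pro')  # placeholder model name
    print(reflect('Suggest a sharper version of this system instruction.'))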
+13 -11

@@ -65,6 +65,7 @@
     "#@title Install GEPA\n",
     "!git clone https://github.com/google/adk-python.git\n",
     "!pip install gepa --quiet\n",
+    "!pip install litellm --quiet\n",
     "!pip install retry --quiet"
    ]
   },
@@ -112,7 +113,7 @@
     "import os\n",
     "\n",
     "from google.genai import types\n",
-    "import experiment as experiment_lib\n",
+    "import utils\n",
     "\n",
     "\n",
     "# @markdown ### ☁️ Configure Vertex AI Access\n",
@@ -139,7 +140,7 @@
     "for logger in loggers:\n",
     "  logger.setLevel(logging.WARNING)\n",
     "\n",
-    "types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
+    "types.logger.addFilter(utils.FilterInferenceWarnings())"
    ]
   },
   {
@@ -179,7 +180,7 @@
     "from google.adk.agents import base_agent\n",
     "from google.adk.agents import llm_agent\n",
     "\n",
-    "import tools\n",
+    "from voter_agent import tools\n",
     "\n",
     "\n",
     "# @markdown ### 🧠 Configure our ADK LLM Agent\n",
@@ -368,7 +369,10 @@
     "    return [line.strip() for line in open(filename) if line.strip()]\n",
     "\n",
     "\n",
-    "voter_data = _read_prompts('prompts.txt')\n",
+    "_AGENT_DIR = 'adk-python/contributing/samples/gepa/voter_agent'\n",
+    "\n",
+    "\n",
+    "voter_data = _read_prompts(f'{_AGENT_DIR}/prompts.txt')\n",
     "voter_data"
    ]
   },
@@ -392,7 +396,8 @@
    "execution_count": null,
    "metadata": {
     "id": "9bHh93RuKVMu",
-    "outputId": "489761d4-da39-43ca-cd08-225c44bb3027"
+    "outputId": "489761d4-da39-43ca-cd08-225c44bb3027",
+    "cellView": "form"
    },
    "outputs": [
     {
@@ -714,7 +719,7 @@
     "    tool_declarations=TOOLS_DESCRIPTION,\n",
     "    developer_instructions='',\n",
     "    rubric=FILTER_RUBRIC,\n",
-    "\n",
+    "    validation_template_path=f'{_AGENT_DIR}/rubric_validation_template.txt',\n",
     ")\n",
     "\n",
     "print(rater(EXAMPLE_TRACE))"
@@ -813,7 +818,7 @@
    "source": [
     "#@title Let's define an evaluation dataset from sample prompts\n",
     "\n",
-    "eval_dataset = _read_prompts('eval_prompts.txt')\n",
+    "eval_dataset = _read_prompts(f'{_AGENT_DIR}/eval_prompts.txt')\n",
     "eval_dataset"
    ]
   },
@@ -2723,7 +2728,7 @@
     "    task_lm=None,  # this must be None when a custom adapter is used\n",
     "    adapter=adapter,\n",
     "    max_metric_calls=MAX_METRIC_CALLS,\n",
-    "    reflection_lm=experiment_lib.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
+    "    reflection_lm=utils.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
     "    reflection_minibatch_size=MINI_BATCH_SIZE,\n",
     ")\n",
     "list(enumerate(gepa_results.val_aggregate_scores))"
@@ -2955,9 +2960,6 @@
   ],
   "metadata": {
    "colab": {
-    "collapsed_sections": [
-     "rIFFNqYoXp6v"
-    ],
     "last_runtime": {
      "build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
      "kind": "private"