diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index e8868909..2f5d03a7 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -26,13 +26,11 @@ import os import random import traceback from typing import Any -from typing import Callable from typing import TypedDict import gepa from gepa.core.adapter import EvaluationBatch from gepa.core.adapter import GEPAAdapter -from google.genai import types from litellm import provider_list import rater_lib from retry import retry @@ -46,20 +44,7 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib -from google import genai - - -class FilterInferenceWarnings(logging.Filter): - """Filters out Vertex inference warning about non-text parts in response.""" - - def filter(self, record: logging.LogRecord) -> bool: - """Filters out Vertex inference warning about non-text parts in response.""" - if record.levelname != 'WARNING': - return True - message_identifier = record.getMessage() - return not message_identifier.startswith( - 'Warning: there are non-text parts in the response:' - ) +import utils def run_tau_bench_rollouts( @@ -494,26 +479,6 @@ def _get_datasets( ) -def reflection_inference_fn(model: str) -> Callable[[str], str]: - """Returns an inference function on VertexAI based on provided model.""" - client = genai.Client() - - @retry(tries=3, delay=10, backoff=2) - def _fn(prompt): - return client.models.generate_content( - model=model, - contents=prompt, - config=types.GenerateContentConfig( - candidate_count=1, - thinking_config=types.ThinkingConfig( - include_thoughts=True, thinking_budget=-1 - ), - ), - ).text - - return _fn - - SEED_SYSTEM_INSTRUCTION = ( 'you are a customer support agent helping customers resolve their ' 'issues by using the right tools' @@ -618,7 +583,7 @@ def run_gepa( task_lm=None, # this must be None when a custom adapter is used adapter=tau_bench_adapter, max_metric_calls=config.max_metric_calls, - reflection_lm=reflection_inference_fn(config.reflection_model), + reflection_lm=utils.reflection_inference_fn(config.reflection_model), reflection_minibatch_size=config.reflection_minibatch_size, run_dir=output_dir, ) diff --git a/contributing/samples/gepa/gepa_tau_bench.ipynb b/contributing/samples/gepa/gepa_tau_bench.ipynb index 9617d3ae..9ca4f318 100644 --- a/contributing/samples/gepa/gepa_tau_bench.ipynb +++ b/contributing/samples/gepa/gepa_tau_bench.ipynb @@ -98,6 +98,7 @@ "\n", "import experiment as experiment_lib\n", "from google.genai import types\n", + "import utils\n", "\n", "\n", "# @markdown ### ☁️ Configure Vertex AI Access\n", @@ -140,7 +141,7 @@ "\n", "# Set a logging verbosity suited for this experiment. See\n", "# https://github.com/google/adk-python/issues/1852 for context\n", - "types.logger.addFilter(experiment_lib.FilterInferenceWarnings())" + "types.logger.addFilter(utils.FilterInferenceWarnings())" ] }, { diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index 1da2efa1..cfd850b3 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -26,6 +26,8 @@ from absl import flags import experiment from google.genai import types +import utils + _OUTPUT_DIR = flags.DEFINE_string( 'output_dir', None, @@ -104,7 +106,7 @@ def main(argv: Sequence[str]) -> None: for logger in loggers: logger.setLevel(logging.WARNING) - types.logger.addFilter(experiment.FilterInferenceWarnings()) + types.logger.addFilter(utils.FilterInferenceWarnings()) output_dir = os.path.join( _OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f') ) diff --git a/contributing/samples/gepa/utils.py b/contributing/samples/gepa/utils.py new file mode 100644 index 00000000..0763d280 --- /dev/null +++ b/contributing/samples/gepa/utils.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines utility for GEPA experiments.""" + +import logging +from typing import Callable + +from google.genai import types +from retry import retry + +from google import genai + + +class FilterInferenceWarnings(logging.Filter): + """Filters out Vertex inference warning about non-text parts in response.""" + + def filter(self, record: logging.LogRecord) -> bool: + """Filters out Vertex inference warning about non-text parts in response.""" + if record.levelname != 'WARNING': + return True + message_identifier = record.getMessage() + return not message_identifier.startswith( + 'Warning: there are non-text parts in the response:' + ) + + +def reflection_inference_fn(model: str) -> Callable[[str], str]: + """Returns an inference function on VertexAI based on provided model.""" + client = genai.Client() + + @retry(tries=3, delay=10, backoff=2) + def _fn(prompt): + return client.models.generate_content( + model=model, + contents=prompt, + config=types.GenerateContentConfig( + candidate_count=1, + thinking_config=types.ThinkingConfig( + include_thoughts=True, thinking_budget=-1 + ), + ), + ).text + + return _fn diff --git a/contributing/samples/gepa/voter_agent/gepa.ipynb b/contributing/samples/gepa/voter_agent/gepa.ipynb index 9c9868f5..d664de8f 100644 --- a/contributing/samples/gepa/voter_agent/gepa.ipynb +++ b/contributing/samples/gepa/voter_agent/gepa.ipynb @@ -65,6 +65,7 @@ "#@title Install GEPA\n", "!git clone https://github.com/google/adk-python.git\n", "!pip install gepa --quiet\n", + "!pip install litellm --quiet\n", "!pip install retry --quiet" ] }, @@ -112,7 +113,7 @@ "import os\n", "\n", "from google.genai import types\n", - "import experiment as experiment_lib\n", + "import utils\n", "\n", "\n", "# @markdown ### ☁️ Configure Vertex AI Access\n", @@ -139,7 +140,7 @@ "for logger in loggers:\n", " logger.setLevel(logging.WARNING)\n", "\n", - "types.logger.addFilter(experiment_lib.FilterInferenceWarnings())" + "types.logger.addFilter(utils.FilterInferenceWarnings())" ] }, { @@ -179,7 +180,7 @@ "from google.adk.agents import base_agent\n", "from google.adk.agents import llm_agent\n", "\n", - "import tools\n", + "from voter_agent import tools\n", "\n", "\n", "# @markdown ### 🧠 Configure our ADK LLM Agent\n", @@ -368,7 +369,10 @@ " return [line.strip() for line in open(filename) if line.strip()]\n", "\n", "\n", - "voter_data = _read_prompts('prompts.txt')\n", + "_AGENT_DIR = 'adk-python/contributing/samples/gepa/voter_agent'\n", + "\n", + "\n", + "voter_data = _read_prompts(f'{_AGENT_DIR}/prompts.txt')\n", "voter_data" ] }, @@ -392,7 +396,8 @@ "execution_count": null, "metadata": { "id": "9bHh93RuKVMu", - "outputId": "489761d4-da39-43ca-cd08-225c44bb3027" + "outputId": "489761d4-da39-43ca-cd08-225c44bb3027", + "cellView": "form" }, "outputs": [ { @@ -714,7 +719,7 @@ " tool_declarations=TOOLS_DESCRIPTION,\n", " developer_instructions='',\n", " rubric=FILTER_RUBRIC,\n", - "\n", + " validation_template_path=f'{_AGENT_DIR}/rubric_validation_template.txt',\n", ")\n", "\n", "print(rater(EXAMPLE_TRACE))" @@ -813,7 +818,7 @@ "source": [ "#@title Let's define an evaluation dataset from sample prompts\n", "\n", - "eval_dataset = _read_prompts('eval_prompts.txt')\n", + "eval_dataset = _read_prompts(f'{_AGENT_DIR}/eval_prompts.txt')\n", "eval_dataset" ] }, @@ -2723,7 +2728,7 @@ " task_lm=None, # this must be None when a custom adapter is used\n", " adapter=adapter,\n", " max_metric_calls=MAX_METRIC_CALLS,\n", - " reflection_lm=experiment_lib.reflection_inference_fn(REFLECTION_MODEL_NAME),\n", + " reflection_lm=utils.reflection_inference_fn(REFLECTION_MODEL_NAME),\n", " reflection_minibatch_size=MINI_BATCH_SIZE,\n", ")\n", "list(enumerate(gepa_results.val_aggregate_scores))" @@ -2955,9 +2960,6 @@ ], "metadata": { "colab": { - "collapsed_sections": [ - "rIFFNqYoXp6v" - ], "last_runtime": { "build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook", "kind": "private"