You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Some small infra fixes to the gepa demo colab
PiperOrigin-RevId: 829240716
This commit is contained in:
committed by
Copybara-Service
parent
d118479ccf
commit
a0cf97eba2
@@ -26,13 +26,11 @@ import os
|
||||
import random
|
||||
import traceback
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import TypedDict
|
||||
|
||||
import gepa
|
||||
from gepa.core.adapter import EvaluationBatch
|
||||
from gepa.core.adapter import GEPAAdapter
|
||||
from google.genai import types
|
||||
from litellm import provider_list
|
||||
import rater_lib
|
||||
from retry import retry
|
||||
@@ -46,20 +44,7 @@ from tau_bench.types import EnvRunResult
|
||||
from tau_bench.types import RunConfig
|
||||
import tau_bench_agent as tau_bench_agent_lib
|
||||
|
||||
from google import genai
|
||||
|
||||
|
||||
class FilterInferenceWarnings(logging.Filter):
|
||||
"""Filters out Vertex inference warning about non-text parts in response."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""Filters out Vertex inference warning about non-text parts in response."""
|
||||
if record.levelname != 'WARNING':
|
||||
return True
|
||||
message_identifier = record.getMessage()
|
||||
return not message_identifier.startswith(
|
||||
'Warning: there are non-text parts in the response:'
|
||||
)
|
||||
import utils
|
||||
|
||||
|
||||
def run_tau_bench_rollouts(
|
||||
@@ -494,26 +479,6 @@ def _get_datasets(
|
||||
)
|
||||
|
||||
|
||||
def reflection_inference_fn(model: str) -> Callable[[str], str]:
|
||||
"""Returns an inference function on VertexAI based on provided model."""
|
||||
client = genai.Client()
|
||||
|
||||
@retry(tries=3, delay=10, backoff=2)
|
||||
def _fn(prompt):
|
||||
return client.models.generate_content(
|
||||
model=model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
candidate_count=1,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
include_thoughts=True, thinking_budget=-1
|
||||
),
|
||||
),
|
||||
).text
|
||||
|
||||
return _fn
|
||||
|
||||
|
||||
SEED_SYSTEM_INSTRUCTION = (
|
||||
'you are a customer support agent helping customers resolve their '
|
||||
'issues by using the right tools'
|
||||
@@ -618,7 +583,7 @@ def run_gepa(
|
||||
task_lm=None, # this must be None when a custom adapter is used
|
||||
adapter=tau_bench_adapter,
|
||||
max_metric_calls=config.max_metric_calls,
|
||||
reflection_lm=reflection_inference_fn(config.reflection_model),
|
||||
reflection_lm=utils.reflection_inference_fn(config.reflection_model),
|
||||
reflection_minibatch_size=config.reflection_minibatch_size,
|
||||
run_dir=output_dir,
|
||||
)
|
||||
|
||||
@@ -98,6 +98,7 @@
|
||||
"\n",
|
||||
"import experiment as experiment_lib\n",
|
||||
"from google.genai import types\n",
|
||||
"import utils\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# @markdown ### ☁️ Configure Vertex AI Access\n",
|
||||
@@ -140,7 +141,7 @@
|
||||
"\n",
|
||||
"# Set a logging verbosity suited for this experiment. See\n",
|
||||
"# https://github.com/google/adk-python/issues/1852 for context\n",
|
||||
"types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
|
||||
"types.logger.addFilter(utils.FilterInferenceWarnings())"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -26,6 +26,8 @@ from absl import flags
|
||||
import experiment
|
||||
from google.genai import types
|
||||
|
||||
import utils
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir',
|
||||
None,
|
||||
@@ -104,7 +106,7 @@ def main(argv: Sequence[str]) -> None:
|
||||
for logger in loggers:
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
types.logger.addFilter(experiment.FilterInferenceWarnings())
|
||||
types.logger.addFilter(utils.FilterInferenceWarnings())
|
||||
output_dir = os.path.join(
|
||||
_OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f')
|
||||
)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines utility for GEPA experiments."""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from google.genai import types
|
||||
from retry import retry
|
||||
|
||||
from google import genai
|
||||
|
||||
|
||||
class FilterInferenceWarnings(logging.Filter):
|
||||
"""Filters out Vertex inference warning about non-text parts in response."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""Filters out Vertex inference warning about non-text parts in response."""
|
||||
if record.levelname != 'WARNING':
|
||||
return True
|
||||
message_identifier = record.getMessage()
|
||||
return not message_identifier.startswith(
|
||||
'Warning: there are non-text parts in the response:'
|
||||
)
|
||||
|
||||
|
||||
def reflection_inference_fn(model: str) -> Callable[[str], str]:
|
||||
"""Returns an inference function on VertexAI based on provided model."""
|
||||
client = genai.Client()
|
||||
|
||||
@retry(tries=3, delay=10, backoff=2)
|
||||
def _fn(prompt):
|
||||
return client.models.generate_content(
|
||||
model=model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
candidate_count=1,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
include_thoughts=True, thinking_budget=-1
|
||||
),
|
||||
),
|
||||
).text
|
||||
|
||||
return _fn
|
||||
@@ -65,6 +65,7 @@
|
||||
"#@title Install GEPA\n",
|
||||
"!git clone https://github.com/google/adk-python.git\n",
|
||||
"!pip install gepa --quiet\n",
|
||||
"!pip install litellm --quiet\n",
|
||||
"!pip install retry --quiet"
|
||||
]
|
||||
},
|
||||
@@ -112,7 +113,7 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"from google.genai import types\n",
|
||||
"import experiment as experiment_lib\n",
|
||||
"import utils\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# @markdown ### ☁️ Configure Vertex AI Access\n",
|
||||
@@ -139,7 +140,7 @@
|
||||
"for logger in loggers:\n",
|
||||
" logger.setLevel(logging.WARNING)\n",
|
||||
"\n",
|
||||
"types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
|
||||
"types.logger.addFilter(utils.FilterInferenceWarnings())"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -179,7 +180,7 @@
|
||||
"from google.adk.agents import base_agent\n",
|
||||
"from google.adk.agents import llm_agent\n",
|
||||
"\n",
|
||||
"import tools\n",
|
||||
"from voter_agent import tools\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# @markdown ### 🧠 Configure our ADK LLM Agent\n",
|
||||
@@ -368,7 +369,10 @@
|
||||
" return [line.strip() for line in open(filename) if line.strip()]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"voter_data = _read_prompts('prompts.txt')\n",
|
||||
"_AGENT_DIR = 'adk-python/contributing/samples/gepa/voter_agent'\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"voter_data = _read_prompts(f'{_AGENT_DIR}/prompts.txt')\n",
|
||||
"voter_data"
|
||||
]
|
||||
},
|
||||
@@ -392,7 +396,8 @@
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9bHh93RuKVMu",
|
||||
"outputId": "489761d4-da39-43ca-cd08-225c44bb3027"
|
||||
"outputId": "489761d4-da39-43ca-cd08-225c44bb3027",
|
||||
"cellView": "form"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -714,7 +719,7 @@
|
||||
" tool_declarations=TOOLS_DESCRIPTION,\n",
|
||||
" developer_instructions='',\n",
|
||||
" rubric=FILTER_RUBRIC,\n",
|
||||
"\n",
|
||||
" validation_template_path=f'{_AGENT_DIR}/rubric_validation_template.txt',\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(rater(EXAMPLE_TRACE))"
|
||||
@@ -813,7 +818,7 @@
|
||||
"source": [
|
||||
"#@title Let's define an evaluation dataset from sample prompts\n",
|
||||
"\n",
|
||||
"eval_dataset = _read_prompts('eval_prompts.txt')\n",
|
||||
"eval_dataset = _read_prompts(f'{_AGENT_DIR}/eval_prompts.txt')\n",
|
||||
"eval_dataset"
|
||||
]
|
||||
},
|
||||
@@ -2723,7 +2728,7 @@
|
||||
" task_lm=None, # this must be None when a custom adapter is used\n",
|
||||
" adapter=adapter,\n",
|
||||
" max_metric_calls=MAX_METRIC_CALLS,\n",
|
||||
" reflection_lm=experiment_lib.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
|
||||
" reflection_lm=utils.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
|
||||
" reflection_minibatch_size=MINI_BATCH_SIZE,\n",
|
||||
")\n",
|
||||
"list(enumerate(gepa_results.val_aggregate_scores))"
|
||||
@@ -2955,9 +2960,6 @@
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [
|
||||
"rIFFNqYoXp6v"
|
||||
],
|
||||
"last_runtime": {
|
||||
"build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
|
||||
"kind": "private"
|
||||
|
||||
Reference in New Issue
Block a user