You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Update AgentEvaluator to handle async ADK agent definitions
AgentEvaluator should recognize root_agent and get_agent_async as valid structures for ADK agent definitions. PiperOrigin-RevId: 819976635
This commit is contained in:
committed by
Copybara-Service
parent
a17f3b2e6d
commit
86097afe49
@@ -118,7 +118,7 @@ class AgentEvaluator:
|
||||
Args:
|
||||
agent_module: The path to python module that contains the definition of
|
||||
the agent. There is convention in place here, where the code is going to
|
||||
look for 'root_agent' in the loaded module.
|
||||
look for 'root_agent' or 'get_agent_async' in the loaded module.
|
||||
eval_set: The eval set.
|
||||
criteria: Evaluation criteria, a dictionary of metric names to their
|
||||
respective thresholds. This field is deprecated.
|
||||
@@ -144,7 +144,7 @@ class AgentEvaluator:
|
||||
if eval_config is None:
|
||||
raise ValueError("`eval_config` is required.")
|
||||
|
||||
agent_for_eval = AgentEvaluator._get_agent_for_eval(
|
||||
agent_for_eval = await AgentEvaluator._get_agent_for_eval(
|
||||
module_name=agent_module, agent_name=agent_name
|
||||
)
|
||||
eval_metrics = get_eval_metrics_from_config(eval_config)
|
||||
@@ -200,7 +200,7 @@ class AgentEvaluator:
|
||||
Args:
|
||||
agent_module: The path to python module that contains the definition of
|
||||
the agent. There is convention in place here, where the code is going to
|
||||
look for 'root_agent' in the loaded module.
|
||||
look for 'root_agent' or 'get_agent_async' in the loaded module.
|
||||
eval_dataset_file_path_or_dir: The eval data set. This can be either a
|
||||
string representing full path to the file containing eval dataset, or a
|
||||
directory that is recursively explored for all files that have a
|
||||
@@ -466,12 +466,26 @@ class AgentEvaluator:
|
||||
return "\n".join([str(t) for t in tool_calls])
|
||||
|
||||
@staticmethod
|
||||
def _get_agent_for_eval(
|
||||
async def _get_agent_for_eval(
|
||||
module_name: str, agent_name: Optional[str] = None
|
||||
) -> BaseAgent:
|
||||
module_path = f"{module_name}"
|
||||
agent_module = importlib.import_module(module_path)
|
||||
root_agent = agent_module.agent.root_agent
|
||||
print(dir(agent_module))
|
||||
if hasattr(agent_module, "agent"):
|
||||
if hasattr(agent_module.agent, "root_agent"):
|
||||
root_agent = agent_module.agent.root_agent
|
||||
elif hasattr(agent_module.agent, "get_agent_async"):
|
||||
root_agent, _ = await agent_module.agent.get_agent_async()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Module {module_name} does not have a root_agent or"
|
||||
" get_agent_async method."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Module {module_name} does not have a member named `agent`."
|
||||
)
|
||||
|
||||
agent_for_eval = root_agent
|
||||
if agent_name:
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Hello world agent from agent 1.0 revised to be defined with get_agent_async
|
||||
# instead of root_agent - https://colab.sandbox.google.com/drive/1Zq-nqmgK0nCERCv8jKIaoeTTgbNn6oSo?resourcekey=0-GYaz9pFT4wY8CI8Cvjy5GA#scrollTo=u3X3XwDOaCv9
|
||||
import contextlib
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.agents import llm_agent
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def roll_die(sides: int) -> int:
  """Simulate one roll of a die with the given number of sides.

  Args:
    sides: How many sides the die has.

  Returns:
    An integer between 1 and `sides`, inclusive.
  """
  lowest_face = 1
  highest_face = sides
  return random.randint(lowest_face, highest_face)
|
||||
|
||||
|
||||
def check_prime(nums: list[int]) -> str:
  """Report which of the given numbers are prime.

  Args:
    nums: The numbers to check. Each entry is coerced with `int()` before
      testing, and duplicates are reported only once.

  Returns:
    A sentence listing the distinct primes found, in ascending order, or
    'No prime numbers found.' when none of the inputs is prime.
  """
  primes = set()
  for raw_value in nums:
    number = int(raw_value)
    # 0, 1 and negative numbers are never prime.
    if number <= 1:
      continue
    # Trial division up to the square root is enough to find any factor.
    if all(number % i != 0 for i in range(2, int(number**0.5) + 1)):
      primes.add(number)
  if not primes:
    return 'No prime numbers found.'
  # Sort so the message does not depend on set iteration order.
  return f"{', '.join(str(num) for num in sorted(primes))} are prime numbers."
|
||||
|
||||
|
||||
async def get_agent_async() -> (
    tuple[llm_agent.LlmAgent, Optional[contextlib.AsyncExitStack]]
):
  """Build the dice-rolling agent and hand it back to the caller.

  Returns:
    A 2-tuple of the constructed agent and an optional async exit stack. This
    implementation creates no exit stack, so the second element is always
    None.
  """
  # NOTE(review): the prompt text below is reproduced as it appears in this
  # view; confirm its leading whitespace against the checked-in file.
  root_agent = Agent(
      model='gemini-2.0-flash-001',
      name='data_processing_agent',
      instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
The only things you do are roll dice for the user and discuss the outcomes.
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
      # The two module-level functions are exposed to the model as tools.
      tools=[
          roll_die,
          check_prime,
      ],
      generate_content_config=types.GenerateContentConfig(
          safety_settings=[
              types.SafetySetting(  # avoid false alarm about rolling dice.
                  category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                  threshold=types.HarmBlockThreshold.OFF,
              ),
          ]
      ),
  )
  # No resources need cleanup here, so no exit stack is returned.
  return root_agent, None
|
||||
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"eval_set_id": "56540925-a5ff-49fe-a4e1-589fe78066f2",
|
||||
"name": "56540925-a5ff-49fe-a4e1-589fe78066f2",
|
||||
"description": null,
|
||||
"eval_cases": [
|
||||
{
|
||||
"eval_id": "tests/integration/fixture/hello_world_agent_async/roll_die.test.json",
|
||||
"conversation": [
|
||||
{
|
||||
"invocation_id": "b01f67f0-9f23-44d6-bbe4-36ea235cb9fb",
|
||||
"user_content": {
|
||||
"parts": [
|
||||
{
|
||||
"video_metadata": null,
|
||||
"thought": null,
|
||||
"code_execution_result": null,
|
||||
"executable_code": null,
|
||||
"file_data": null,
|
||||
"function_call": null,
|
||||
"function_response": null,
|
||||
"inline_data": null,
|
||||
"text": "Hi who are you?"
|
||||
}
|
||||
],
|
||||
"role": "user"
|
||||
},
|
||||
"final_response": {
|
||||
"parts": [
|
||||
{
|
||||
"video_metadata": null,
|
||||
"thought": null,
|
||||
"code_execution_result": null,
|
||||
"executable_code": null,
|
||||
"file_data": null,
|
||||
"function_call": null,
|
||||
"function_response": null,
|
||||
"inline_data": null,
|
||||
"text": "I am a data processing agent. I can roll dice and check if the results are prime numbers. What would you like me to do? \n"
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"intermediate_data": {
|
||||
"tool_uses": [],
|
||||
"intermediate_responses": []
|
||||
},
|
||||
"creation_timestamp": 1747341775.8937013
|
||||
}
|
||||
],
|
||||
"session_input": null,
|
||||
"creation_timestamp": 1747341775.8937826
|
||||
}
|
||||
],
|
||||
"creation_timestamp": 1747341775.8937957
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"criteria": {
|
||||
"tool_trajectory_avg_score": 1.0,
|
||||
"response_match_score": 0.5
|
||||
}
|
||||
}
|
||||
@@ -23,3 +23,14 @@ async def test_eval_agent():
|
||||
eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json",
|
||||
num_runs=4,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_eval_agent_async():
  """Evaluate the fixture agent that is exposed through `get_agent_async`."""
  fixture_module = "tests.integration.fixture.hello_world_agent_async"
  eval_dataset = (
      "tests/integration/fixture/hello_world_agent_async/roll_die.test.json"
  )
  await AgentEvaluator.evaluate(
      agent_module=fixture_module,
      eval_dataset_file_path_or_dir=eval_dataset,
      num_runs=4,
  )
|
||||
|
||||
Reference in New Issue
Block a user