diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index 049d2bf7..3788a2e6 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -118,7 +118,7 @@ class AgentEvaluator: Args: agent_module: The path to python module that contains the definition of the agent. There is convention in place here, where the code is going to - look for 'root_agent' in the loaded module. + look for 'root_agent' or `get_agent_async` in the loaded module. eval_set: The eval set. criteria: Evauation criterias, a dictionary of metric names to their respective thresholds. This field is deprecated. @@ -144,7 +144,7 @@ class AgentEvaluator: if eval_config is None: raise ValueError("`eval_config` is required.") - agent_for_eval = AgentEvaluator._get_agent_for_eval( + agent_for_eval = await AgentEvaluator._get_agent_for_eval( module_name=agent_module, agent_name=agent_name ) eval_metrics = get_eval_metrics_from_config(eval_config) @@ -200,7 +200,7 @@ class AgentEvaluator: Args: agent_module: The path to python module that contains the definition of the agent. There is convention in place here, where the code is going to - look for 'root_agent' in the loaded module. + look for 'root_agent' or 'get_agent_async' in the loaded module. eval_dataset_file_path_or_dir: The eval data set. This can be either a string representing full path to the file containing eval dataset, or a directory that is recursively explored for all files that have a @@ -466,12 +466,26 @@ class AgentEvaluator: return "\n".join([str(t) for t in tool_calls]) @staticmethod - def _get_agent_for_eval( + async def _get_agent_for_eval( module_name: str, agent_name: Optional[str] = None ) -> BaseAgent: module_path = f"{module_name}" agent_module = importlib.import_module(module_path) - root_agent = agent_module.agent.root_agent + print(dir(agent_module)) + if hasattr(agent_module, "agent"): + if hasattr(agent_module.agent, "root_agent"): + root_agent = agent_module.agent.root_agent + elif hasattr(agent_module.agent, "get_agent_async"): + root_agent, _ = await agent_module.agent.get_agent_async() + else: + raise ValueError( + f"Module {module_name} does not have a root_agent or" + " get_agent_async method." + ) + else: + raise ValueError( + f"Module {module_name} does not have a member named `agent`." + ) agent_for_eval = root_agent if agent_name: diff --git a/tests/integration/fixture/hello_world_agent_async/__init__.py b/tests/integration/fixture/hello_world_agent_async/__init__.py new file mode 100644 index 00000000..c48963cd --- /dev/null +++ b/tests/integration/fixture/hello_world_agent_async/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/tests/integration/fixture/hello_world_agent_async/agent.py b/tests/integration/fixture/hello_world_agent_async/agent.py new file mode 100644 index 00000000..b105065c --- /dev/null +++ b/tests/integration/fixture/hello_world_agent_async/agent.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Hello world agent from agent 1.0 revised to be defined with get_agent_async +# instead of root_agent - https://colab.sandbox.google.com/drive/1Zq-nqmgK0nCERCv8jKIaoeTTgbNn6oSo?resourcekey=0-GYaz9pFT4wY8CI8Cvjy5GA#scrollTo=u3X3XwDOaCv9 +import contextlib +import random +from typing import Optional + +from google.adk import Agent +from google.adk.agents import llm_agent +from google.genai import types + + +def roll_die(sides: int) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + return random.randint(1, sides) + + +def check_prime(nums: list[int]) -> list[str]: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +async def get_agent_async() -> ( + tuple[llm_agent.LlmAgent, Optional[contextlib.AsyncExitStack]] +): + """Returns the root agent.""" + root_agent = Agent( + model='gemini-2.0-flash-001', + name='data_processing_agent', + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + The only things you do are roll dice for the user and discuss the outcomes. + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), + ) + return root_agent, None diff --git a/tests/integration/fixture/hello_world_agent_async/roll_die.test.json b/tests/integration/fixture/hello_world_agent_async/roll_die.test.json new file mode 100644 index 00000000..7e787d40 --- /dev/null +++ b/tests/integration/fixture/hello_world_agent_async/roll_die.test.json @@ -0,0 +1,55 @@ +{ + "eval_set_id": "56540925-a5ff-49fe-a4e1-589fe78066f2", + "name": "56540925-a5ff-49fe-a4e1-589fe78066f2", + "description": null, + "eval_cases": [ + { + "eval_id": "tests/integration/fixture/hello_world_agent_async/roll_die.test.json", + "conversation": [ + { + "invocation_id": "b01f67f0-9f23-44d6-bbe4-36ea235cb9fb", + "user_content": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "code_execution_result": null, + "executable_code": null, + "file_data": null, + "function_call": null, + "function_response": null, + "inline_data": null, + "text": "Hi who are you?" + } + ], + "role": "user" + }, + "final_response": { + "parts": [ + { + "video_metadata": null, + "thought": null, + "code_execution_result": null, + "executable_code": null, + "file_data": null, + "function_call": null, + "function_response": null, + "inline_data": null, + "text": "I am a data processing agent. I can roll dice and check if the results are prime numbers. What would you like me to do? \n" + } + ], + "role": "model" + }, + "intermediate_data": { + "tool_uses": [], + "intermediate_responses": [] + }, + "creation_timestamp": 1747341775.8937013 + } + ], + "session_input": null, + "creation_timestamp": 1747341775.8937826 + } + ], + "creation_timestamp": 1747341775.8937957 +} \ No newline at end of file diff --git a/tests/integration/fixture/hello_world_agent_async/test_config.json b/tests/integration/fixture/hello_world_agent_async/test_config.json new file mode 100644 index 00000000..c7fba6a4 --- /dev/null +++ b/tests/integration/fixture/hello_world_agent_async/test_config.json @@ -0,0 +1,6 @@ +{ + "criteria": { + "tool_trajectory_avg_score": 1.0, + "response_match_score": 0.5 + } +} diff --git a/tests/integration/test_single_agent.py b/tests/integration/test_single_agent.py index 183005ed..49f3b7ba 100644 --- a/tests/integration/test_single_agent.py +++ b/tests/integration/test_single_agent.py @@ -23,3 +23,14 @@ async def test_eval_agent(): eval_dataset_file_path_or_dir="tests/integration/fixture/home_automation_agent/simple_test.test.json", num_runs=4, ) + + +@pytest.mark.asyncio +async def test_eval_agent_async(): + await AgentEvaluator.evaluate( + agent_module="tests.integration.fixture.hello_world_agent_async", + eval_dataset_file_path_or_dir=( + "tests/integration/fixture/hello_world_agent_async/roll_die.test.json" + ), + num_runs=4, + )