feat: Upgrade ADK stack to use App instead in addition to root_agent

The convention:
- If some fields(like plugin) are defined both at root_agent and app, then a error will be raised.
- app code should be located within agent.py.
- an instance named app should be created

PiperOrigin-RevId: 803155804
This commit is contained in:
Hangfei Lin
2025-09-04 13:37:15 -07:00
committed by Copybara-Service
parent 14484065c6
commit 4df79dd5c9
11 changed files with 476 additions and 36 deletions
+15
View File
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
+145
View File
@@ -0,0 +1,145 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.apps import App
from google.adk.models.llm_request import LlmRequest
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if not 'rolls' in tool_context.state:
tool_context.state['rolls'] = []
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
return result
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
'No prime numbers found.'
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
root_agent = Agent(
model='gemini-2.0-flash',
name='hello_world_agent',
description=(
'hello world agent that can roll a dice of 8 sides and check prime'
' numbers.'
),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
tools=[
roll_die,
check_prime,
],
# planner=BuiltInPlanner(
# thinking_config=types.ThinkingConfig(
# include_thoughts=True,
# ),
# ),
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)
class CountInvocationPlugin(BasePlugin):
"""A custom plugin that counts agent and tool invocations."""
def __init__(self) -> None:
"""Initialize the plugin with counters."""
super().__init__(name='count_invocation')
self.agent_count: int = 0
self.tool_count: int = 0
self.llm_request_count: int = 0
async def before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> None:
"""Count agent runs."""
self.agent_count += 1
print(f'[Plugin] Agent run count: {self.agent_count}')
async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> None:
"""Count LLM requests."""
self.llm_request_count += 1
print(f'[Plugin] LLM request count: {self.llm_request_count}')
app = App(
name='hello_world_app',
root_agent=root_agent,
plugins=[CountInvocationPlugin()],
)
+103
View File
@@ -0,0 +1,103 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import agent
from dotenv import load_dotenv
from google.adk.agents.run_config import RunConfig
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions.session import Session
from google.genai import types
load_dotenv(override=True)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
runner = InMemoryRunner(
agent=agent.root_agent,
app_name=app_name,
)
session_11 = await runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=str.encode(new_message), mime_type='text/plain'
)
],
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
async def check_rolls_in_state(rolls_size: int):
session = await runner.session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
assert len(session.state['rolls']) == rolls_size
for roll in session.state['rolls']:
assert roll > 0 and roll <= 100
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
await check_rolls_in_state(1)
await run_prompt(session_11, 'Roll a die again with 100 sides.')
await check_rolls_in_state(2)
await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes')
print(
await runner.artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
asyncio.run(main())
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .app import App
__all__ = [
'App',
]
+52
View File
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import ABC
from typing import Optional
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from ..agents.base_agent import BaseAgent
from ..plugins.base_plugin import BasePlugin
from ..utils.feature_decorator import experimental
@experimental
class App(BaseModel):
"""Represents an LLM-backed agentic application.
An `App` is the top-level container for an agentic system powered by LLMs.
It manages a root agent (`root_agent`), which serves as the root of an agent
tree, enabling coordination and communication across all agents in the
hierarchy.
The `plugins` are application-wide components that provide shared capabilities
and services to the entire system.
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)
name: str
"""The name of the application."""
root_agent: BaseAgent
"""The root agent in the application. One app can only have one root agent."""
plugins: list[BasePlugin] = Field(default_factory=list)
"""The plugins in the application."""
+16 -6
View File
@@ -50,10 +50,12 @@ from typing_extensions import override
from watchdog.observers import Observer
from . import agent_graph
from ..agents.base_agent import BaseAgent
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.run_config import RunConfig
from ..agents.run_config import StreamingMode
from ..apps.app import App
from ..artifacts.base_artifact_service import BaseArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
from ..errors.not_found_error import NotFoundError
@@ -322,10 +324,17 @@ class AdkWebServer:
envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir)
if app_name in self.runner_dict:
return self.runner_dict[app_name]
root_agent = self.agent_loader.load_agent(app_name)
agent_or_app = self.agent_loader.load_agent(app_name)
agentic_app = None
if isinstance(agent_or_app, BaseAgent):
agentic_app = App(
name=app_name,
root_agent=agent_or_app,
)
else:
agentic_app = agent_or_app
runner = Runner(
app_name=app_name,
agent=root_agent,
app=agentic_app,
artifact_service=self.artifact_service,
session_service=self.session_service,
memory_service=self.memory_service,
@@ -624,9 +633,10 @@ class AdkWebServer:
invocations = evals.convert_session_to_eval_invocations(session)
# Populate the session with initial session state.
initial_session_state = create_empty_state(
self.agent_loader.load_agent(app_name)
)
agent_or_app = self.agent_loader.load_agent(app_name)
if isinstance(agent_or_app, App):
agent_or_app = agent_or_app.root_agent
initial_session_state = create_empty_state(agent_or_app)
new_eval_case = EvalCase(
eval_id=req.eval_id,
+18 -7
View File
@@ -16,12 +16,15 @@ from __future__ import annotations
from datetime import datetime
from typing import Optional
from typing import Union
import click
from google.genai import types
from pydantic import BaseModel
from ..agents.base_agent import BaseAgent
from ..agents.llm_agent import LlmAgent
from ..apps.app import App
from ..artifacts.base_artifact_service import BaseArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
@@ -43,15 +46,19 @@ class InputFile(BaseModel):
async def run_input_file(
app_name: str,
user_id: str,
root_agent: LlmAgent,
agent_or_app: Union[LlmAgent, App],
artifact_service: BaseArtifactService,
session_service: BaseSessionService,
credential_service: BaseCredentialService,
input_path: str,
) -> Session:
app = (
agent_or_app
if isinstance(agent_or_app, App)
else App(name=app_name, root_agent=agent_or_app)
)
runner = Runner(
app_name=app_name,
agent=root_agent,
app=app,
artifact_service=artifact_service,
session_service=session_service,
credential_service=credential_service,
@@ -79,15 +86,19 @@ async def run_input_file(
async def run_interactively(
root_agent: LlmAgent,
root_agent_or_app: Union[LlmAgent, App],
artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService,
credential_service: BaseCredentialService,
) -> None:
app = (
root_agent_or_app
if isinstance(root_agent_or_app, App)
else App(name=session.app_name, root_agent=root_agent_or_app)
)
runner = Runner(
app_name=session.app_name,
agent=root_agent,
app=app,
artifact_service=artifact_service,
session_service=session_service,
credential_service=credential_service,
@@ -154,7 +165,7 @@ async def run_cli(
session = await run_input_file(
app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent,
agent_or_app=root_agent,
artifact_service=artifact_service,
session_service=session_service,
credential_service=credential_service,
+26 -10
View File
@@ -20,6 +20,7 @@ import os
from pathlib import Path
import sys
from typing import Optional
from typing import Union
from pydantic import ValidationError
from typing_extensions import override
@@ -27,6 +28,7 @@ from typing_extensions import override
from . import envs
from ...agents import config_agent_utils
from ...agents.base_agent import BaseAgent
from ...apps.app import App
from ...utils.feature_decorator import experimental
from .base_agent_loader import BaseAgentLoader
@@ -55,19 +57,25 @@ class AgentLoader(BaseAgentLoader):
def __init__(self, agents_dir: str):
self.agents_dir = agents_dir.rstrip("/")
self._original_sys_path = None
self._agent_cache: dict[str, BaseAgent] = {}
self._agent_cache: dict[str, Union[BaseAgent, App]] = {}
def _load_from_module_or_package(
self, agent_name: str
) -> Optional[BaseAgent]:
) -> Optional[Union[BaseAgent, App]]:
# Load for case: Import "{agent_name}" (as a package or module)
# Covers structures:
# a) agents_dir/{agent_name}.py (with root_agent in the module)
# b) agents_dir/{agent_name}/__init__.py (with root_agent in the package)
try:
module_candidate = importlib.import_module(agent_name)
# Check for "app" first, then "root_agent"
if hasattr(module_candidate, "app") and isinstance(
module_candidate.app, App
):
logger.debug("Found app in %s", agent_name)
return module_candidate.app
# Check for "root_agent" directly in "{agent_name}" module/package
if hasattr(module_candidate, "root_agent"):
elif hasattr(module_candidate, "root_agent"):
logger.debug("Found root_agent directly in %s", agent_name)
if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent
@@ -101,12 +109,20 @@ class AgentLoader(BaseAgentLoader):
return None
def _load_from_submodule(self, agent_name: str) -> Optional[BaseAgent]:
def _load_from_submodule(
self, agent_name: str
) -> Optional[Union[BaseAgent], App]:
# Load for case: Import "{agent_name}.agent" and look for "root_agent"
# Covers structure: agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
try:
module_candidate = importlib.import_module(f"{agent_name}.agent")
if hasattr(module_candidate, "root_agent"):
# Check for "app" first, then "root_agent"
if hasattr(module_candidate, "app") and isinstance(
module_candidate.app, App
):
logger.debug("Found app in %s.agent", agent_name)
return module_candidate.app
elif hasattr(module_candidate, "root_agent"):
logger.info("Found root_agent in %s.agent", agent_name)
if isinstance(module_candidate.root_agent, BaseAgent):
return module_candidate.root_agent
@@ -168,7 +184,7 @@ class AgentLoader(BaseAgentLoader):
) + e.args[1:]
raise e
def _perform_load(self, agent_name: str) -> BaseAgent:
def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]:
"""Internal logic to load an agent"""
# Determine the directory to use for loading
if agent_name.startswith("__"):
@@ -208,16 +224,16 @@ class AgentLoader(BaseAgentLoader):
)
@override
def load_agent(self, agent_name: str) -> BaseAgent:
def load_agent(self, agent_name: str) -> Union[BaseAgent, App]:
"""Load an agent module (with caching & .env) and return its root_agent."""
if agent_name in self._agent_cache:
logger.debug("Returning cached agent for %s (async)", agent_name)
return self._agent_cache[agent_name]
logger.debug("Loading agent %s - not in cache.", agent_name)
agent = self._perform_load(agent_name)
self._agent_cache[agent_name] = agent
return agent
agent_or_app = self._perform_load(agent_name)
self._agent_cache[agent_name] = agent_or_app
return agent_or_app
@override
def list_agents(self) -> list[str]:
@@ -18,15 +18,17 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from typing import Union
from ...agents.base_agent import BaseAgent
from ...apps.app import App
class BaseAgentLoader(ABC):
"""Abstract base class for agent loaders."""
@abstractmethod
def load_agent(self, agent_name: str) -> BaseAgent:
def load_agent(self, agent_name: str) -> Union[BaseAgent, App]:
"""Loads an instance of an agent with the given name."""
@abstractmethod
+75 -9
View File
@@ -34,6 +34,7 @@ from .agents.invocation_context import new_invocation_context_id
from .agents.live_request_queue import LiveRequestQueue
from .agents.llm_agent import LlmAgent
from .agents.run_config import RunConfig
from .apps.app import App
from .artifacts.base_artifact_service import BaseArtifactService
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
from .auth.credential_service.base_credential_service import BaseCredentialService
@@ -91,8 +92,9 @@ class Runner:
def __init__(
self,
*,
app_name: str,
agent: BaseAgent,
app: Optional[App] = None,
app_name: Optional[str] = None,
agent: Optional[BaseAgent] = None,
plugins: Optional[List[BasePlugin]] = None,
artifact_service: Optional[BaseArtifactService] = None,
session_service: BaseSessionService,
@@ -101,23 +103,85 @@ class Runner:
):
"""Initializes the Runner.
Developers should provide either an `app` instance or both `app_name` and
`agent`. Providing a mix of `app` and `app_name`/`agent` will result in a
`ValueError`. Providing `app` is the recommended way to create a runner.
Args:
app_name: The application name of the runner.
agent: The root agent to run.
plugins: A list of plugins for the runner.
app_name: The application name of the runner. Required if `app` is not
provided.
agent: The root agent to run. Required if `app` is not provided.
app: An optional `App` instance. If provided, `app_name` and `agent`
should not be specified.
plugins: Deprecated. A list of plugins for the runner. Please use the
`app` argument to provide plugins instead.
artifact_service: The artifact service for the runner.
session_service: The session service for the runner.
memory_service: The memory service for the runner.
credential_service: The credential service for the runner.
Raises:
ValueError: If `app` is provided along with `app_name` or `plugins`, or
if `app` is not provided but either `app_name` or `agent` is missing.
"""
self.app_name = app_name
self.agent = agent
self.app_name, self.agent, plugins = self._validate_runner_params(
app, app_name, agent, plugins
)
self.artifact_service = artifact_service
self.session_service = session_service
self.memory_service = memory_service
self.credential_service = credential_service
self.plugin_manager = PluginManager(plugins=plugins)
def _validate_runner_params(
self,
app: Optional[App],
app_name: Optional[str],
agent: Optional[BaseAgent],
plugins: Optional[List[BasePlugin]],
) -> tuple[str, BaseAgent, Optional[List[BasePlugin]]]:
"""Validates and extracts runner parameters.
Args:
app: An optional `App` instance.
app_name: The application name of the runner.
agent: The root agent to run.
plugins: A list of plugins for the runner.
Returns:
A tuple containing (app_name, agent, plugins).
Raises:
ValueError: If parameters are invalid.
"""
if app:
if app_name:
raise ValueError(
'When app is provided, app_name should not be provided.'
)
if agent:
raise ValueError('When app is provided, agent should not be provided.')
if plugins:
raise ValueError(
'When app is provided, plugins should not be provided and should be'
' provided in the app instead.'
)
app_name = app.name
agent = app.root_agent
plugins = app.plugins
elif not app_name or not agent:
raise ValueError(
'Either app or both app_name and agent must be provided.'
)
if plugins:
warnings.warn(
'The `plugins` argument is deprecated. Please use the `app` argument'
' to provide plugins instead.',
DeprecationWarning,
)
return app_name, agent, plugins
def run(
self,
*,
@@ -658,10 +722,11 @@ class InMemoryRunner(Runner):
def __init__(
self,
agent: BaseAgent,
agent: Optional[BaseAgent] = None,
*,
app_name: str = 'InMemoryRunner',
app_name: Optional[str] = 'InMemoryRunner',
plugins: Optional[list[BasePlugin]] = None,
app: Optional[App] = None,
):
"""Initializes the InMemoryRunner.
@@ -676,6 +741,7 @@ class InMemoryRunner(Runner):
agent=agent,
artifact_service=InMemoryArtifactService(),
plugins=plugins,
app=app,
session_service=self._in_memory_session_service,
memory_service=InMemoryMemoryService(),
)
+4 -3
View File
@@ -26,6 +26,7 @@ from typing import List
from typing import Tuple
import click
from google.adk.agents.base_agent import BaseAgent
import google.adk.cli.cli as cli
import pytest
@@ -130,12 +131,12 @@ async def test_run_input_file_outputs(
artifact_service = cli.InMemoryArtifactService()
session_service = cli.InMemorySessionService()
credential_service = cli.InMemoryCredentialService()
dummy_root = types.SimpleNamespace(name="root")
dummy_root = BaseAgent(name="root")
session = await cli.run_input_file(
app_name="app",
user_id="user",
root_agent=dummy_root,
agent_or_app=dummy_root,
artifact_service=artifact_service,
session_service=session_service,
credential_service=credential_service,
@@ -205,7 +206,7 @@ async def test_run_interactively_whitespace_and_exit(
sess = await session_service.create_session(app_name="dummy", user_id="u")
artifact_service = cli.InMemoryArtifactService()
credential_service = cli.InMemoryCredentialService()
root_agent = types.SimpleNamespace(name="root")
root_agent = BaseAgent(name="root")
# fake user input: blank -> 'hello' -> 'exit'
answers = iter([" ", "hello", "exit"])