You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Upgrade ADK stack to use App instead in addition to root_agent
The convention: - If some fields(like plugin) are defined both at root_agent and app, then a error will be raised. - app code should be located within agent.py. - an instance named app should be created PiperOrigin-RevId: 803155804
This commit is contained in:
committed by
Copybara-Service
parent
14484065c6
commit
4df79dd5c9
+15
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
+145
@@ -0,0 +1,145 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.apps import App
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def roll_die(sides: int, tool_context: ToolContext) -> int:
|
||||
"""Roll a die and return the rolled result.
|
||||
|
||||
Args:
|
||||
sides: The integer number of sides the die has.
|
||||
|
||||
Returns:
|
||||
An integer of the result of rolling the die.
|
||||
"""
|
||||
result = random.randint(1, sides)
|
||||
if not 'rolls' in tool_context.state:
|
||||
tool_context.state['rolls'] = []
|
||||
|
||||
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
|
||||
return result
|
||||
|
||||
|
||||
async def check_prime(nums: list[int]) -> str:
|
||||
"""Check if a given list of numbers are prime.
|
||||
|
||||
Args:
|
||||
nums: The list of numbers to check.
|
||||
|
||||
Returns:
|
||||
A str indicating which number is prime.
|
||||
"""
|
||||
primes = set()
|
||||
for number in nums:
|
||||
number = int(number)
|
||||
if number <= 1:
|
||||
continue
|
||||
is_prime = True
|
||||
for i in range(2, int(number**0.5) + 1):
|
||||
if number % i == 0:
|
||||
is_prime = False
|
||||
break
|
||||
if is_prime:
|
||||
primes.add(number)
|
||||
return (
|
||||
'No prime numbers found.'
|
||||
if not primes
|
||||
else f"{', '.join(str(num) for num in primes)} are prime numbers."
|
||||
)
|
||||
|
||||
|
||||
root_agent = Agent(
|
||||
model='gemini-2.0-flash',
|
||||
name='hello_world_agent',
|
||||
description=(
|
||||
'hello world agent that can roll a dice of 8 sides and check prime'
|
||||
' numbers.'
|
||||
),
|
||||
instruction="""
|
||||
You roll dice and answer questions about the outcome of the dice rolls.
|
||||
You can roll dice of different sizes.
|
||||
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
|
||||
It is ok to discuss previous dice roles, and comment on the dice rolls.
|
||||
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
|
||||
You should never roll a die on your own.
|
||||
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
|
||||
You should not check prime numbers before calling the tool.
|
||||
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
|
||||
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
|
||||
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
|
||||
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
|
||||
3. When you respond, you must include the roll_die result from step 1.
|
||||
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
|
||||
You should not rely on the previous history on prime results.
|
||||
""",
|
||||
tools=[
|
||||
roll_die,
|
||||
check_prime,
|
||||
],
|
||||
# planner=BuiltInPlanner(
|
||||
# thinking_config=types.ThinkingConfig(
|
||||
# include_thoughts=True,
|
||||
# ),
|
||||
# ),
|
||||
generate_content_config=types.GenerateContentConfig(
|
||||
safety_settings=[
|
||||
types.SafetySetting( # avoid false alarm about rolling dice.
|
||||
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold=types.HarmBlockThreshold.OFF,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CountInvocationPlugin(BasePlugin):
|
||||
"""A custom plugin that counts agent and tool invocations."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the plugin with counters."""
|
||||
super().__init__(name='count_invocation')
|
||||
self.agent_count: int = 0
|
||||
self.tool_count: int = 0
|
||||
self.llm_request_count: int = 0
|
||||
|
||||
async def before_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
) -> None:
|
||||
"""Count agent runs."""
|
||||
self.agent_count += 1
|
||||
print(f'[Plugin] Agent run count: {self.agent_count}')
|
||||
|
||||
async def before_model_callback(
|
||||
self, *, callback_context: CallbackContext, llm_request: LlmRequest
|
||||
) -> None:
|
||||
"""Count LLM requests."""
|
||||
self.llm_request_count += 1
|
||||
print(f'[Plugin] LLM request count: {self.llm_request_count}')
|
||||
|
||||
|
||||
app = App(
|
||||
name='hello_world_app',
|
||||
root_agent=root_agent,
|
||||
plugins=[CountInvocationPlugin()],
|
||||
)
|
||||
Executable
+103
@@ -0,0 +1,103 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import agent
|
||||
from dotenv import load_dotenv
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.cli.utils import logs
|
||||
from google.adk.runners import InMemoryRunner
|
||||
from google.adk.sessions.session import Session
|
||||
from google.genai import types
|
||||
|
||||
load_dotenv(override=True)
|
||||
logs.log_to_tmp_folder()
|
||||
|
||||
|
||||
async def main():
|
||||
app_name = 'my_app'
|
||||
user_id_1 = 'user1'
|
||||
runner = InMemoryRunner(
|
||||
agent=agent.root_agent,
|
||||
app_name=app_name,
|
||||
)
|
||||
session_11 = await runner.session_service.create_session(
|
||||
app_name=app_name, user_id=user_id_1
|
||||
)
|
||||
|
||||
async def run_prompt(session: Session, new_message: str):
|
||||
content = types.Content(
|
||||
role='user', parts=[types.Part.from_text(text=new_message)]
|
||||
)
|
||||
print('** User says:', content.model_dump(exclude_none=True))
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id_1,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
|
||||
async def run_prompt_bytes(session: Session, new_message: str):
|
||||
content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_bytes(
|
||||
data=str.encode(new_message), mime_type='text/plain'
|
||||
)
|
||||
],
|
||||
)
|
||||
print('** User says:', content.model_dump(exclude_none=True))
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id_1,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
run_config=RunConfig(save_input_blobs_as_artifacts=True),
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
|
||||
async def check_rolls_in_state(rolls_size: int):
|
||||
session = await runner.session_service.get_session(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_11.id
|
||||
)
|
||||
assert len(session.state['rolls']) == rolls_size
|
||||
for roll in session.state['rolls']:
|
||||
assert roll > 0 and roll <= 100
|
||||
|
||||
start_time = time.time()
|
||||
print('Start time:', start_time)
|
||||
print('------------------------------------')
|
||||
await run_prompt(session_11, 'Hi')
|
||||
await run_prompt(session_11, 'Roll a die with 100 sides')
|
||||
await check_rolls_in_state(1)
|
||||
await run_prompt(session_11, 'Roll a die again with 100 sides.')
|
||||
await check_rolls_in_state(2)
|
||||
await run_prompt(session_11, 'What numbers did I got?')
|
||||
await run_prompt_bytes(session_11, 'Hi bytes')
|
||||
print(
|
||||
await runner.artifact_service.list_artifact_keys(
|
||||
app_name=app_name, user_id=user_id_1, session_id=session_11.id
|
||||
)
|
||||
)
|
||||
end_time = time.time()
|
||||
print('------------------------------------')
|
||||
print('End time:', end_time)
|
||||
print('Total time:', end_time - start_time)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .app import App
|
||||
|
||||
__all__ = [
|
||||
'App',
|
||||
]
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..plugins.base_plugin import BasePlugin
|
||||
from ..utils.feature_decorator import experimental
|
||||
|
||||
|
||||
@experimental
|
||||
class App(BaseModel):
|
||||
"""Represents an LLM-backed agentic application.
|
||||
|
||||
An `App` is the top-level container for an agentic system powered by LLMs.
|
||||
It manages a root agent (`root_agent`), which serves as the root of an agent
|
||||
tree, enabling coordination and communication across all agents in the
|
||||
hierarchy.
|
||||
The `plugins` are application-wide components that provide shared capabilities
|
||||
and services to the entire system.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
name: str
|
||||
"""The name of the application."""
|
||||
|
||||
root_agent: BaseAgent
|
||||
"""The root agent in the application. One app can only have one root agent."""
|
||||
|
||||
plugins: list[BasePlugin] = Field(default_factory=list)
|
||||
"""The plugins in the application."""
|
||||
@@ -50,10 +50,12 @@ from typing_extensions import override
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from . import agent_graph
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.run_config import RunConfig
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..apps.app import App
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
@@ -322,10 +324,17 @@ class AdkWebServer:
|
||||
envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir)
|
||||
if app_name in self.runner_dict:
|
||||
return self.runner_dict[app_name]
|
||||
root_agent = self.agent_loader.load_agent(app_name)
|
||||
agent_or_app = self.agent_loader.load_agent(app_name)
|
||||
agentic_app = None
|
||||
if isinstance(agent_or_app, BaseAgent):
|
||||
agentic_app = App(
|
||||
name=app_name,
|
||||
root_agent=agent_or_app,
|
||||
)
|
||||
else:
|
||||
agentic_app = agent_or_app
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
app=agentic_app,
|
||||
artifact_service=self.artifact_service,
|
||||
session_service=self.session_service,
|
||||
memory_service=self.memory_service,
|
||||
@@ -624,9 +633,10 @@ class AdkWebServer:
|
||||
invocations = evals.convert_session_to_eval_invocations(session)
|
||||
|
||||
# Populate the session with initial session state.
|
||||
initial_session_state = create_empty_state(
|
||||
self.agent_loader.load_agent(app_name)
|
||||
)
|
||||
agent_or_app = self.agent_loader.load_agent(app_name)
|
||||
if isinstance(agent_or_app, App):
|
||||
agent_or_app = agent_or_app.root_agent
|
||||
initial_session_state = create_empty_state(agent_or_app)
|
||||
|
||||
new_eval_case = EvalCase(
|
||||
eval_id=req.eval_id,
|
||||
|
||||
@@ -16,12 +16,15 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import click
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..apps.app import App
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from ..auth.credential_service.base_credential_service import BaseCredentialService
|
||||
@@ -43,15 +46,19 @@ class InputFile(BaseModel):
|
||||
async def run_input_file(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
root_agent: LlmAgent,
|
||||
agent_or_app: Union[LlmAgent, App],
|
||||
artifact_service: BaseArtifactService,
|
||||
session_service: BaseSessionService,
|
||||
credential_service: BaseCredentialService,
|
||||
input_path: str,
|
||||
) -> Session:
|
||||
app = (
|
||||
agent_or_app
|
||||
if isinstance(agent_or_app, App)
|
||||
else App(name=app_name, root_agent=agent_or_app)
|
||||
)
|
||||
runner = Runner(
|
||||
app_name=app_name,
|
||||
agent=root_agent,
|
||||
app=app,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
credential_service=credential_service,
|
||||
@@ -79,15 +86,19 @@ async def run_input_file(
|
||||
|
||||
|
||||
async def run_interactively(
|
||||
root_agent: LlmAgent,
|
||||
root_agent_or_app: Union[LlmAgent, App],
|
||||
artifact_service: BaseArtifactService,
|
||||
session: Session,
|
||||
session_service: BaseSessionService,
|
||||
credential_service: BaseCredentialService,
|
||||
) -> None:
|
||||
app = (
|
||||
root_agent_or_app
|
||||
if isinstance(root_agent_or_app, App)
|
||||
else App(name=session.app_name, root_agent=root_agent_or_app)
|
||||
)
|
||||
runner = Runner(
|
||||
app_name=session.app_name,
|
||||
agent=root_agent,
|
||||
app=app,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
credential_service=credential_service,
|
||||
@@ -154,7 +165,7 @@ async def run_cli(
|
||||
session = await run_input_file(
|
||||
app_name=agent_folder_name,
|
||||
user_id=user_id,
|
||||
root_agent=root_agent,
|
||||
agent_or_app=root_agent,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
credential_service=credential_service,
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ValidationError
|
||||
from typing_extensions import override
|
||||
@@ -27,6 +28,7 @@ from typing_extensions import override
|
||||
from . import envs
|
||||
from ...agents import config_agent_utils
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...apps.app import App
|
||||
from ...utils.feature_decorator import experimental
|
||||
from .base_agent_loader import BaseAgentLoader
|
||||
|
||||
@@ -55,19 +57,25 @@ class AgentLoader(BaseAgentLoader):
|
||||
def __init__(self, agents_dir: str):
|
||||
self.agents_dir = agents_dir.rstrip("/")
|
||||
self._original_sys_path = None
|
||||
self._agent_cache: dict[str, BaseAgent] = {}
|
||||
self._agent_cache: dict[str, Union[BaseAgent, App]] = {}
|
||||
|
||||
def _load_from_module_or_package(
|
||||
self, agent_name: str
|
||||
) -> Optional[BaseAgent]:
|
||||
) -> Optional[Union[BaseAgent, App]]:
|
||||
# Load for case: Import "{agent_name}" (as a package or module)
|
||||
# Covers structures:
|
||||
# a) agents_dir/{agent_name}.py (with root_agent in the module)
|
||||
# b) agents_dir/{agent_name}/__init__.py (with root_agent in the package)
|
||||
try:
|
||||
module_candidate = importlib.import_module(agent_name)
|
||||
# Check for "app" first, then "root_agent"
|
||||
if hasattr(module_candidate, "app") and isinstance(
|
||||
module_candidate.app, App
|
||||
):
|
||||
logger.debug("Found app in %s", agent_name)
|
||||
return module_candidate.app
|
||||
# Check for "root_agent" directly in "{agent_name}" module/package
|
||||
if hasattr(module_candidate, "root_agent"):
|
||||
elif hasattr(module_candidate, "root_agent"):
|
||||
logger.debug("Found root_agent directly in %s", agent_name)
|
||||
if isinstance(module_candidate.root_agent, BaseAgent):
|
||||
return module_candidate.root_agent
|
||||
@@ -101,12 +109,20 @@ class AgentLoader(BaseAgentLoader):
|
||||
|
||||
return None
|
||||
|
||||
def _load_from_submodule(self, agent_name: str) -> Optional[BaseAgent]:
|
||||
def _load_from_submodule(
|
||||
self, agent_name: str
|
||||
) -> Optional[Union[BaseAgent], App]:
|
||||
# Load for case: Import "{agent_name}.agent" and look for "root_agent"
|
||||
# Covers structure: agents_dir/{agent_name}/agent.py (with root_agent defined in the module)
|
||||
try:
|
||||
module_candidate = importlib.import_module(f"{agent_name}.agent")
|
||||
if hasattr(module_candidate, "root_agent"):
|
||||
# Check for "app" first, then "root_agent"
|
||||
if hasattr(module_candidate, "app") and isinstance(
|
||||
module_candidate.app, App
|
||||
):
|
||||
logger.debug("Found app in %s.agent", agent_name)
|
||||
return module_candidate.app
|
||||
elif hasattr(module_candidate, "root_agent"):
|
||||
logger.info("Found root_agent in %s.agent", agent_name)
|
||||
if isinstance(module_candidate.root_agent, BaseAgent):
|
||||
return module_candidate.root_agent
|
||||
@@ -168,7 +184,7 @@ class AgentLoader(BaseAgentLoader):
|
||||
) + e.args[1:]
|
||||
raise e
|
||||
|
||||
def _perform_load(self, agent_name: str) -> BaseAgent:
|
||||
def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]:
|
||||
"""Internal logic to load an agent"""
|
||||
# Determine the directory to use for loading
|
||||
if agent_name.startswith("__"):
|
||||
@@ -208,16 +224,16 @@ class AgentLoader(BaseAgentLoader):
|
||||
)
|
||||
|
||||
@override
|
||||
def load_agent(self, agent_name: str) -> BaseAgent:
|
||||
def load_agent(self, agent_name: str) -> Union[BaseAgent, App]:
|
||||
"""Load an agent module (with caching & .env) and return its root_agent."""
|
||||
if agent_name in self._agent_cache:
|
||||
logger.debug("Returning cached agent for %s (async)", agent_name)
|
||||
return self._agent_cache[agent_name]
|
||||
|
||||
logger.debug("Loading agent %s - not in cache.", agent_name)
|
||||
agent = self._perform_load(agent_name)
|
||||
self._agent_cache[agent_name] = agent
|
||||
return agent
|
||||
agent_or_app = self._perform_load(agent_name)
|
||||
self._agent_cache[agent_name] = agent_or_app
|
||||
return agent_or_app
|
||||
|
||||
@override
|
||||
def list_agents(self) -> list[str]:
|
||||
|
||||
@@ -18,15 +18,17 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Union
|
||||
|
||||
from ...agents.base_agent import BaseAgent
|
||||
from ...apps.app import App
|
||||
|
||||
|
||||
class BaseAgentLoader(ABC):
|
||||
"""Abstract base class for agent loaders."""
|
||||
|
||||
@abstractmethod
|
||||
def load_agent(self, agent_name: str) -> BaseAgent:
|
||||
def load_agent(self, agent_name: str) -> Union[BaseAgent, App]:
|
||||
"""Loads an instance of an agent with the given name."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -34,6 +34,7 @@ from .agents.invocation_context import new_invocation_context_id
|
||||
from .agents.live_request_queue import LiveRequestQueue
|
||||
from .agents.llm_agent import LlmAgent
|
||||
from .agents.run_config import RunConfig
|
||||
from .apps.app import App
|
||||
from .artifacts.base_artifact_service import BaseArtifactService
|
||||
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from .auth.credential_service.base_credential_service import BaseCredentialService
|
||||
@@ -91,8 +92,9 @@ class Runner:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
agent: BaseAgent,
|
||||
app: Optional[App] = None,
|
||||
app_name: Optional[str] = None,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
plugins: Optional[List[BasePlugin]] = None,
|
||||
artifact_service: Optional[BaseArtifactService] = None,
|
||||
session_service: BaseSessionService,
|
||||
@@ -101,23 +103,85 @@ class Runner:
|
||||
):
|
||||
"""Initializes the Runner.
|
||||
|
||||
Developers should provide either an `app` instance or both `app_name` and
|
||||
`agent`. Providing a mix of `app` and `app_name`/`agent` will result in a
|
||||
`ValueError`. Providing `app` is the recommended way to create a runner.
|
||||
|
||||
Args:
|
||||
app_name: The application name of the runner.
|
||||
agent: The root agent to run.
|
||||
plugins: A list of plugins for the runner.
|
||||
app_name: The application name of the runner. Required if `app` is not
|
||||
provided.
|
||||
agent: The root agent to run. Required if `app` is not provided.
|
||||
app: An optional `App` instance. If provided, `app_name` and `agent`
|
||||
should not be specified.
|
||||
plugins: Deprecated. A list of plugins for the runner. Please use the
|
||||
`app` argument to provide plugins instead.
|
||||
artifact_service: The artifact service for the runner.
|
||||
session_service: The session service for the runner.
|
||||
memory_service: The memory service for the runner.
|
||||
credential_service: The credential service for the runner.
|
||||
|
||||
Raises:
|
||||
ValueError: If `app` is provided along with `app_name` or `plugins`, or
|
||||
if `app` is not provided but either `app_name` or `agent` is missing.
|
||||
"""
|
||||
self.app_name = app_name
|
||||
self.agent = agent
|
||||
self.app_name, self.agent, plugins = self._validate_runner_params(
|
||||
app, app_name, agent, plugins
|
||||
)
|
||||
self.artifact_service = artifact_service
|
||||
self.session_service = session_service
|
||||
self.memory_service = memory_service
|
||||
self.credential_service = credential_service
|
||||
self.plugin_manager = PluginManager(plugins=plugins)
|
||||
|
||||
def _validate_runner_params(
|
||||
self,
|
||||
app: Optional[App],
|
||||
app_name: Optional[str],
|
||||
agent: Optional[BaseAgent],
|
||||
plugins: Optional[List[BasePlugin]],
|
||||
) -> tuple[str, BaseAgent, Optional[List[BasePlugin]]]:
|
||||
"""Validates and extracts runner parameters.
|
||||
|
||||
Args:
|
||||
app: An optional `App` instance.
|
||||
app_name: The application name of the runner.
|
||||
agent: The root agent to run.
|
||||
plugins: A list of plugins for the runner.
|
||||
|
||||
Returns:
|
||||
A tuple containing (app_name, agent, plugins).
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid.
|
||||
"""
|
||||
if app:
|
||||
if app_name:
|
||||
raise ValueError(
|
||||
'When app is provided, app_name should not be provided.'
|
||||
)
|
||||
if agent:
|
||||
raise ValueError('When app is provided, agent should not be provided.')
|
||||
if plugins:
|
||||
raise ValueError(
|
||||
'When app is provided, plugins should not be provided and should be'
|
||||
' provided in the app instead.'
|
||||
)
|
||||
app_name = app.name
|
||||
agent = app.root_agent
|
||||
plugins = app.plugins
|
||||
elif not app_name or not agent:
|
||||
raise ValueError(
|
||||
'Either app or both app_name and agent must be provided.'
|
||||
)
|
||||
|
||||
if plugins:
|
||||
warnings.warn(
|
||||
'The `plugins` argument is deprecated. Please use the `app` argument'
|
||||
' to provide plugins instead.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return app_name, agent, plugins
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
@@ -658,10 +722,11 @@ class InMemoryRunner(Runner):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: BaseAgent,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
*,
|
||||
app_name: str = 'InMemoryRunner',
|
||||
app_name: Optional[str] = 'InMemoryRunner',
|
||||
plugins: Optional[list[BasePlugin]] = None,
|
||||
app: Optional[App] = None,
|
||||
):
|
||||
"""Initializes the InMemoryRunner.
|
||||
|
||||
@@ -676,6 +741,7 @@ class InMemoryRunner(Runner):
|
||||
agent=agent,
|
||||
artifact_service=InMemoryArtifactService(),
|
||||
plugins=plugins,
|
||||
app=app,
|
||||
session_service=self._in_memory_session_service,
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import click
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
import google.adk.cli.cli as cli
|
||||
import pytest
|
||||
|
||||
@@ -130,12 +131,12 @@ async def test_run_input_file_outputs(
|
||||
artifact_service = cli.InMemoryArtifactService()
|
||||
session_service = cli.InMemorySessionService()
|
||||
credential_service = cli.InMemoryCredentialService()
|
||||
dummy_root = types.SimpleNamespace(name="root")
|
||||
dummy_root = BaseAgent(name="root")
|
||||
|
||||
session = await cli.run_input_file(
|
||||
app_name="app",
|
||||
user_id="user",
|
||||
root_agent=dummy_root,
|
||||
agent_or_app=dummy_root,
|
||||
artifact_service=artifact_service,
|
||||
session_service=session_service,
|
||||
credential_service=credential_service,
|
||||
@@ -205,7 +206,7 @@ async def test_run_interactively_whitespace_and_exit(
|
||||
sess = await session_service.create_session(app_name="dummy", user_id="u")
|
||||
artifact_service = cli.InMemoryArtifactService()
|
||||
credential_service = cli.InMemoryCredentialService()
|
||||
root_agent = types.SimpleNamespace(name="root")
|
||||
root_agent = BaseAgent(name="root")
|
||||
|
||||
# fake user input: blank -> 'hello' -> 'exit'
|
||||
answers = iter([" ", "hello", "exit"])
|
||||
|
||||
Reference in New Issue
Block a user