Merge branch 'main' into fix/missing-path-level-parameters

This commit is contained in:
Jeffrey Mak
2025-05-14 10:57:06 -04:00
committed by GitHub
84 changed files with 3139 additions and 1305 deletions
+27 -30
View File
@@ -4,21 +4,21 @@ We'd love to accept your patches and contributions to this project.
## Table of Contents
- [Before you begin](#before-you-begin)
- [Sign our Contributor License Agreement](#sign-our-contributor-license-agreement)
- [Before you begin](#before-you-begin)
- [Sign our Contributor License Agreement](#sign-our-contributor-license-agreement)
- [Review our community guidelines](#review-our-community-guidelines)
- [Contribution workflow](#contribution-workflow)
- [Finding Issues to Work On](#finding-issues-to-work-on)
- [Requirement for PRs](#requirement-for-prs)
- [Large or Complex Changes](#large-or-complex-changes)
- [Testing Requirements](#testing-requirements)
- [Unit Tests](#unit-tests)
- [End-to-End (E2E) Tests](#manual-end-to-end-e2e-tests)
- [Documentation](#documentation)
- [Development Setup](#development-setup)
- [Contribution workflow](#contribution-workflow)
- [Finding Issues to Work On](#finding-issues-to-work-on)
- [Requirement for PRs](#requirement-for-prs)
- [Large or Complex Changes](#large-or-complex-changes)
- [Testing Requirements](#testing-requirements)
- [Unit Tests](#unit-tests)
- [End-to-End (E2E) Tests](#manual-end-to-end-e2e-tests)
- [Documentation](#documentation)
- [Development Setup](#development-setup)
- [Code reviews](#code-reviews)
## Before you begin
### Sign our Contributor License Agreement
@@ -44,13 +44,13 @@ This project follows
### Finding Issues to Work On
- Browse issues labeled **`good first issue`** (newcomer-friendly) or **`help wanted`** (general contributions).
- Browse issues labeled **`good first issue`** (newcomer-friendly) or **`help wanted`** (general contributions).
- For other issues, please kindly ask before contributing to avoid duplication.
### Requirement for PRs
- All PRs, other than small documentation or typo fixes, should have a Issue assoicated. If not, please create one.
- All PRs, other than small documentation or typo fixes, should have a Issue assoicated. If not, please create one.
- Small, focused PRs. Keep changes minimal—one concern per PR.
- For bug fixes or features, please provide logs or screenshot after the fix is applied to help reviewers better understand the fix.
- Please include a `testing plan` section in your PR to talk about how you will test. This will save time for PR review. See `Testing Requirements` section for more details.
@@ -72,12 +72,12 @@ Please add or update unit tests for your change. Please include a summary of pas
Requirements for unit tests:
- **Coverage:** Cover new features, edge cases, error conditions, and typical use cases.
- **Location:** Add or update tests under `tests/unittests/`, following existing naming conventions (e.g., `test_<module>_<feature>.py`).
- **Framework:** Use `pytest`. Tests should be:
- Fast and isolated.
- Written clearly with descriptive names.
- Free of external dependencies (use mocks or fixtures as needed).
- **Coverage:** Cover new features, edge cases, error conditions, and typical use cases.
- **Location:** Add or update tests under `tests/unittests/`, following existing naming conventions (e.g., `test_<module>_<feature>.py`).
- **Framework:** Use `pytest`. Tests should be:
- Fast and isolated.
- Written clearly with descriptive names.
- Free of external dependencies (use mocks or fixtures as needed).
- **Quality:** Aim for high readability and maintainability; include docstrings or comments for complex scenarios.
#### Manual End-to-End (E2E) Tests
@@ -86,15 +86,15 @@ Manual E2E tests ensure integrated flows work as intended. Your tests should cov
Depending on your change:
- **ADK Web:**
- Use the `adk web` to verify functionality.
- Capture and attach relevant screenshots demonstrating the UI/UX changes or outputs.
- **ADK Web:**
- Use the `adk web` to verify functionality.
- Capture and attach relevant screenshots demonstrating the UI/UX changes or outputs.
- Label screenshots clearly in your PR description.
- **Runner:**
- Provide the testing setup. For example, the agent definition, and the runner setup.
- Execute the `runner` tool to reproduce workflows.
- Include the command used and console output showing test results.
- Execute the `runner` tool to reproduce workflows.
- Include the command used and console output showing test results.
- Highlight sections of the log that directly relate to your change.
### Documentation
@@ -117,19 +117,16 @@ For any changes that impact user-facing documentation (guides, API reference, tu
```shell
source .venv/bin/activate
```
**windows**
```shell
source .\.venv\Scripts\activate
```
```shell
pip install uv
```
3. **Install dependencies:**
```shell
pip install uv
uv sync --all-extras
```
4. **Run unit tests:**
@@ -20,7 +20,6 @@ from dotenv import load_dotenv
from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.application_integration_tool import ApplicationIntegrationToolset
# Load environment variables from .env file
load_dotenv()
@@ -29,12 +28,12 @@ connection_project = os.getenv("CONNECTION_PROJECT")
connection_location = os.getenv("CONNECTION_LOCATION")
jira_tool = ApplicationIntegrationToolset(
jira_toolset = ApplicationIntegrationToolset(
project=connection_project,
location=connection_location,
connection=connection_name,
entity_operations={"Issues": [], "Projects": []},
tool_name="jira_issue_manager",
tool_name_prefix="jira_issue_manager",
)
root_agent = LlmAgent(
@@ -46,5 +45,5 @@ root_agent = LlmAgent(
If there is an error in the tool response, understand the error and try and see if you can fix the error and then and execute the tool again. For example if a variable or parameter is missing, try and see if you can find it in the request or user query or default it and then execute the tool again or check for other tools that could give you the details.
If there are any math operations like count or max, min in the user request, call the tool to get the data and perform the math operations and then return the result in the response. For example for maximum, fetch the list and then do the math operation.
""",
tools=jira_tool.get_tools(),
tools=[jira_toolset],
)
+13 -19
View File
@@ -16,7 +16,7 @@ import os
from dotenv import load_dotenv
from google.adk import Agent
from google.adk.tools.google_api_tool import bigquery_tool_set
from google.adk.tools.google_api_tool import bigquery_toolset
# Load environment variables from .env file
load_dotenv()
@@ -24,19 +24,20 @@ load_dotenv()
# Access the variable
oauth_client_id = os.getenv("OAUTH_CLIENT_ID")
oauth_client_secret = os.getenv("OAUTH_CLIENT_SECRET")
bigquery_tool_set.configure_auth(oauth_client_id, oauth_client_secret)
bigquery_toolset.configure_auth(oauth_client_id, oauth_client_secret)
bigquery_datasets_list = bigquery_tool_set.get_tool("bigquery_datasets_list")
bigquery_datasets_get = bigquery_tool_set.get_tool("bigquery_datasets_get")
bigquery_datasets_insert = bigquery_tool_set.get_tool(
"bigquery_datasets_insert"
tools_to_expose = [
"bigquery_datasets_list",
"bigquery_datasets_get",
"bigquery_datasets_insert",
"bigquery_tables_list",
"bigquery_tables_get",
"bigquery_tables_insert",
]
bigquery_toolset.set_tool_filter(
lambda tool, ctx=None: tool.name in tools_to_expose
)
bigquery_tables_list = bigquery_tool_set.get_tool("bigquery_tables_list")
bigquery_tables_get = bigquery_tool_set.get_tool("bigquery_tables_get")
bigquery_tables_insert = bigquery_tool_set.get_tool("bigquery_tables_insert")
root_agent = Agent(
model="gemini-2.0-flash",
name="bigquery_agent",
@@ -73,12 +74,5 @@ root_agent = Agent(
{userInfo?}
</User>
""",
tools=[
bigquery_datasets_list,
bigquery_datasets_get,
bigquery_datasets_insert,
bigquery_tables_list,
bigquery_tables_get,
bigquery_tables_insert,
],
tools=[bigquery_toolset],
)
+15
View File
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
+198
View File
@@ -0,0 +1,198 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk import Agent
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types
def roll_die(sides: int, tool_context: ToolContext) -> int:
"""Roll a die and return the rolled result.
Args:
sides: The integer number of sides the die has.
Returns:
An integer of the result of rolling the die.
"""
result = random.randint(1, sides)
if not 'rolls' in tool_context.state:
tool_context.state['rolls'] = []
tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
return result
async def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime.
Args:
nums: The list of numbers to check.
Returns:
A str indicating which number is prime.
"""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
'No prime numbers found.'
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
async def before_agent_callback(callback_context):
print('@before_agent_callback')
return None
async def after_agent_callback(callback_context):
print('@after_agent_callback')
return None
async def before_model_callback(callback_context, llm_request):
print('@before_model_callback')
return None
async def after_model_callback(callback_context, llm_response):
print('@after_model_callback')
return None
def after_agent_cb1(callback_context):
print('@after_agent_cb1')
def after_agent_cb2(callback_context):
print('@after_agent_cb2')
# ModelContent (or Content with role set to 'model') must be returned.
# Otherwise, the event will be excluded from the context in the next turn.
return types.ModelContent(
parts=[
types.Part(
text='(stopped) after_agent_cb2',
),
],
)
def after_agent_cb3(callback_context):
print('@after_agent_cb3')
def before_agent_cb1(callback_context):
print('@before_agent_cb1')
def before_agent_cb2(callback_context):
print('@before_agent_cb2')
def before_agent_cb3(callback_context):
print('@before_agent_cb3')
def before_tool_cb1(tool, args, tool_context):
print('@before_tool_cb1')
def before_tool_cb2(tool, args, tool_context):
print('@before_tool_cb2')
def before_tool_cb3(tool, args, tool_context):
print('@before_tool_cb3')
def after_tool_cb1(tool, args, tool_context, tool_response):
print('@after_tool_cb1')
def after_tool_cb2(tool, args, tool_context, tool_response):
print('@after_tool_cb2')
return {'test': 'after_tool_cb2', 'response': tool_response}
def after_tool_cb3(tool, args, tool_context, tool_response):
print('@after_tool_cb3')
root_agent = Agent(
model='gemini-2.0-flash-exp',
name='data_processing_agent',
description=(
'hello world agent that can roll a dice of 8 sides and check prime'
' numbers.'
),
instruction="""
You roll dice and answer questions about the outcome of the dice rolls.
You can roll dice of different sizes.
You can use multiple tools in parallel by calling functions in parallel(in one request and in one round).
It is ok to discuss previous dice roles, and comment on the dice rolls.
When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string.
You should never roll a die on your own.
When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string.
You should not check prime numbers before calling the tool.
When you are asked to roll a die and check prime numbers, you should always make the following two function calls:
1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool.
2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result.
2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list.
3. When you respond, you must include the roll_die result from step 1.
You should always perform the previous 3 steps when asking for a roll and checking prime numbers.
You should not rely on the previous history on prime results.
""",
tools=[
roll_die,
check_prime,
],
# planner=BuiltInPlanner(
# thinking_config=types.ThinkingConfig(
# include_thoughts=True,
# ),
# ),
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
before_agent_callback=[
before_agent_cb1,
before_agent_cb2,
before_agent_cb3,
],
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
before_model_callback=before_model_callback,
after_model_callback=after_model_callback,
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
)
+145
View File
@@ -0,0 +1,145 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import warnings
import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
app_name=app_name, user_id=user_id_1
)
async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
role='user',
parts=[
types.Part.from_bytes(
data=str.encode(new_message), mime_type='text/plain'
)
],
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(session_11, 'Roll a die with 100 sides')
await run_prompt(session_11, 'Roll a die again with 100 sides.')
await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes')
print(
await artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
)
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
def main_sync():
app_name = 'my_app'
user_id_1 = 'user1'
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent,
artifact_service=artifact_service,
session_service=session_service,
)
session_11 = session_service.create_session(
app_name=app_name, user_id=user_id_1
)
def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
for event in runner.run(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
run_prompt(session_11, 'Hi')
run_prompt(session_11, 'Roll a die with 100 sides.')
run_prompt(session_11, 'Roll a die again with 100 sides.')
run_prompt(session_11, 'What numbers did I got?')
end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)
if __name__ == '__main__':
print('--------------ASYNC--------------------')
asyncio.run(main())
print('--------------SYNC--------------------')
main_sync()
-87
View File
@@ -65,83 +65,6 @@ async def check_prime(nums: list[int]) -> str:
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
async def before_agent_callback(callback_context):
print('@before_agent_callback')
return None
async def after_agent_callback(callback_context):
print('@after_agent_callback')
return None
async def before_model_callback(callback_context, llm_request):
print('@before_model_callback')
return None
async def after_model_callback(callback_context, llm_response):
print('@after_model_callback')
return None
def after_agent_cb1(callback_context):
print('@after_agent_cb1')
def after_agent_cb2(callback_context):
print('@after_agent_cb2')
return types.Content(
parts=[
types.Part(
text='(stopped) after_agent_cb2',
),
],
)
def after_agent_cb3(callback_context):
print('@after_agent_cb3')
def before_agent_cb1(callback_context):
print('@before_agent_cb1')
def before_agent_cb2(callback_context):
print('@before_agent_cb2')
def before_agent_cb3(callback_context):
print('@before_agent_cb3')
def before_tool_cb1(tool, args, tool_context):
print('@before_tool_cb1')
def before_tool_cb2(tool, args, tool_context):
print('@before_tool_cb2')
def before_tool_cb3(tool, args, tool_context):
print('@before_tool_cb3')
def after_tool_cb1(tool, args, tool_context, tool_response):
print('@after_tool_cb1')
def after_tool_cb2(tool, args, tool_context, tool_response):
print('@after_tool_cb2')
return {'test': 'after_tool_cb2', 'response': tool_response}
def after_tool_cb3(tool, args, tool_context, tool_response):
print('@after_tool_cb3')
root_agent = Agent(
model='gemini-2.0-flash-exp',
name='data_processing_agent',
@@ -183,14 +106,4 @@ root_agent = Agent(
),
]
),
before_agent_callback=[
before_agent_cb1,
before_agent_cb2,
before_agent_cb3,
],
after_agent_callback=[after_agent_cb1, after_agent_cb2, after_agent_cb3],
before_model_callback=before_model_callback,
after_model_callback=after_model_callback,
before_tool_callback=[before_tool_cb1, before_tool_cb2, before_tool_cb3],
after_tool_callback=[after_tool_cb1, after_tool_cb2, after_tool_cb3],
)
+6 -2
View File
@@ -34,8 +34,12 @@ root_agent = LlmAgent(
],
),
# don't want agent to do write operation
tool_predicate=lambda tool, ctx=None: tool.name
not in ('write_file', 'edit_file', 'create_directory', 'move_file'),
tool_filter=[
'write_file',
'edit_file',
'create_directory',
'move_file',
],
)
],
)
+18 -7
View File
@@ -12,20 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from datetime import datetime
from google.adk import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools.load_memory_tool import load_memory_tool
from google.adk.tools.preload_memory_tool import preload_memory_tool
from google.genai import types
def update_current_time(callback_context: CallbackContext):
callback_context.state['_time'] = datetime.now().isoformat()
root_agent = Agent(
model='gemini-2.0-flash-exp',
model='gemini-2.0-flash-001',
name='memory_agent',
description='agent that have access to memory tools.',
instruction="""
You are an agent that help user answer questions.
""",
tools=[load_memory_tool, preload_memory_tool],
before_agent_callback=update_current_time,
instruction="""\
You are an agent that help user answer questions.
Current time: {_time}
""",
tools=[
load_memory_tool,
preload_memory_tool,
],
)
+111
View File
@@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from datetime import datetime
from datetime import timedelta
from typing import cast
import warnings
import agent
from dotenv import load_dotenv
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions import Session
from google.genai import types
load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()
async def main():
app_name = 'my_app'
user_id_1 = 'user1'
runner = InMemoryRunner(
app_name=app_name,
agent=agent.root_agent,
)
async def run_prompt(session: Session, new_message: str) -> Session:
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if not event.content or not event.content.parts:
continue
if event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
elif event.content.parts[0].function_call:
print(
f'** {event.author}: fc /'
f' {event.content.parts[0].function_call.name} /'
f' {event.content.parts[0].function_call.args}\n'
)
elif event.content.parts[0].function_response:
print(
f'** {event.author}: fr /'
f' {event.content.parts[0].function_response.name} /'
f' {event.content.parts[0].function_response.response}\n'
)
return cast(
Session,
runner.session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session.id
),
)
session_1 = runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)
print(f'----Session to create memory: {session_1.id} ----------------------')
session_1 = await run_prompt(session_1, 'Hi')
session_1 = await run_prompt(session_1, 'My name is Jack')
session_1 = await run_prompt(session_1, 'I like badminton.')
session_1 = await run_prompt(
session_1,
f'I ate a burger on {(datetime.now() - timedelta(days=1)).date()}.',
)
session_1 = await run_prompt(
session_1,
f'I ate a banana on {(datetime.now() - timedelta(days=2)).date()}.',
)
print('Saving session to memory service...')
if runner.memory_service:
await runner.memory_service.add_session_to_memory(session_1)
print('-------------------------------------------------------------------')
session_2 = runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)
print(f'----Session to use memory: {session_2.id} ----------------------')
session_2 = await run_prompt(session_2, 'Hi')
session_2 = await run_prompt(session_2, 'What do I like to do?')
# ** memory_agent: You like badminton.
session_2 = await run_prompt(session_2, 'When did I say that?')
# ** memory_agent: You said you liked badminton on ...
session_2 = await run_prompt(session_2, 'What did I eat yesterday?')
# ** memory_agent: You ate a burger yesterday...
print('-------------------------------------------------------------------')
if __name__ == '__main__':
asyncio.run(main())
@@ -27,7 +27,7 @@ from google.adk.auth import AuthCredential
from google.adk.auth import AuthCredentialTypes
from google.adk.auth import OAuth2Auth
from google.adk.tools import ToolContext
from google.adk.tools.google_api_tool import calendar_tool_set
from google.adk.tools.google_api_tool import calendar_toolset
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
@@ -42,12 +42,12 @@ oauth_client_secret = os.getenv("OAUTH_CLIENT_SECRET")
SCOPES = ["https://www.googleapis.com/auth/calendar"]
calendar_tool_set.configure_auth(
calendar_toolset.configure_auth(
client_id=oauth_client_id, client_secret=oauth_client_secret
)
get_calendar_events = calendar_tool_set.get_tool("calendar_events_get")
# list_calendar_events = calendar_tool_set.get_tool("calendar_events_list")
get_calendar_events = calendar_toolset.get_tool("calendar_events_get")
# list_calendar_events = calendar_toolset.get_tool("calendar_events_list")
# you can replace below customized list_calendar_events tool with above ADK
# build-in google calendar tool which is commented for now to acheive same
# effect.
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,94 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.genai import types
# --- Roll Die Sub-Agent ---
def roll_die(sides: int) -> int:
"""Roll a die and return the rolled result."""
return random.randint(1, sides)
roll_agent = LlmAgent(
name="roll_agent",
description="Handles rolling dice of different sizes.",
model="gemini-2.0-flash-exp",
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)
def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime."""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
"No prime numbers found."
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)
prime_agent = LlmAgent(
name="prime_agent",
description="Handles checking if numbers are prime.",
model="gemini-2.0-flash-exp",
instruction="""
You are responsible for checking whether numbers are prime.
When asked to check primes, you must call the check_prime tool with a list of integers.
Never attempt to determine prime numbers manually.
Return the prime number results to the root agent.
""",
tools=[check_prime],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)
root_agent = SequentialAgent(
name="code_pipeline_agent",
sub_agents=[roll_agent, prime_agent],
# The agents will run in the order provided: roll_agent -> prime_agent
)
+9 -6
View File
@@ -15,10 +15,10 @@ classifiers = [ # List of https://pypi.org/classifiers/
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Operating System :: OS Independent",
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
@@ -31,7 +31,8 @@ dependencies = [
"google-api-python-client>=2.157.0", # Google API client discovery
"google-cloud-aiplatform>=1.87.0", # For VertexAI integrations, e.g. example store.
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
"google-cloud-speech>=2.30.0", # For Audo Transcription
"google-cloud-speech>=2.30.0", # For Audio Transcription
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.14.0", # Google GenAI SDK
"graphviz>=0.20.2", # Graphviz for graph rendering
@@ -83,7 +84,8 @@ test = [
"langchain-community>=0.3.17",
"langgraph>=0.2.60", # For LangGraphAgent
"litellm>=1.63.11", # For LiteLLM tests
"llama-index-readers-file>=0.4.0", # for retrieval tests
"llama-index-readers-file>=0.4.0", # For retrieval tests
"pytest-asyncio>=0.25.0",
"pytest-mock>=3.14.0",
"pytest-xdist>=3.6.1",
@@ -108,7 +110,7 @@ extensions = [
"docker>=7.0.0", # For ContainerCodeExecutor
"langgraph>=0.2.60", # For LangGraphAgent
"litellm>=1.63.11", # For LiteLLM support
"llama-index-readers-file>=0.4.0", # for retrieval usings LlamaIndex.
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
"lxml>=5.3.0", # For load_web_page tool.
]
@@ -146,6 +148,7 @@ name = "google.adk"
[tool.isort]
profile = "google"
single_line_exclusions = []
known_third_party = ["google.adk"]
[tool.pytest.ini_options]
+27 -15
View File
@@ -14,16 +14,15 @@
from __future__ import annotations
import inspect
import logging
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Literal,
Optional,
Union,
)
from typing import Any
from typing import AsyncGenerator
from typing import Awaitable
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Union
from google.genai import types
from pydantic import BaseModel
@@ -96,7 +95,9 @@ AfterToolCallback: TypeAlias = Union[
list[_SingleAfterToolCallback],
]
InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
InstructionProvider: TypeAlias = Callable[
[ReadonlyContext], Union[str, Awaitable[str]]
]
ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
ExamplesUnion = Union[list[Example], BaseExampleProvider]
@@ -149,7 +150,12 @@ class LlmAgent(BaseAgent):
# LLM-based agent transfer configs - Start
disallow_transfer_to_parent: bool = False
"""Disallows LLM-controlled transferring to the parent agent."""
"""Disallows LLM-controlled transferring to the parent agent.
NOTE: Setting this as True also prevents this agent to continue reply to the
end-user. This behavior prevents one-way transfer, in which end-user may be
stuck with one agent that cannot transfer to other agents in the agent tree.
"""
disallow_transfer_to_peers: bool = False
"""Disallows LLM-controlled transferring to the peer agents."""
# LLM-based agent transfer configs - End
@@ -302,7 +308,7 @@ class LlmAgent(BaseAgent):
ancestor_agent = ancestor_agent.parent_agent
raise ValueError(f'No model found for {self.name}.')
def canonical_instruction(self, ctx: ReadonlyContext) -> str:
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct instruction for this agent.
This method is only for use by Agent Development Kit.
@@ -310,9 +316,12 @@ class LlmAgent(BaseAgent):
if isinstance(self.instruction, str):
return self.instruction
else:
return self.instruction(ctx)
instruction = self.instruction(ctx)
if inspect.isawaitable(instruction):
instruction = await instruction
return instruction
def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
"""The resolved self.instruction field to construct global instruction.
This method is only for use by Agent Development Kit.
@@ -320,7 +329,10 @@ class LlmAgent(BaseAgent):
if isinstance(self.global_instruction, str):
return self.global_instruction
else:
return self.global_instruction(ctx)
global_instruction = self.global_instruction(ctx)
if inspect.isawaitable(global_instruction):
global_instruction = await global_instruction
return global_instruction
async def canonical_tools(
self, ctx: ReadonlyContext = None
+1 -1
View File
@@ -58,5 +58,5 @@ class LoopAgent(BaseAgent):
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
raise NotImplementedError('The behavior for run_live is not defined yet.')
raise NotImplementedError('This is not supported yet for LoopAgent.')
yield # AsyncGenerator requires having at least one yield statement
+7
View File
@@ -94,3 +94,10 @@ class ParallelAgent(BaseAgent):
agent_runs = [agent.run_async(ctx) for agent in self.sub_agents]
async for event in _merge_agent_run(agent_runs):
yield event
@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
raise NotImplementedError("This is not supported yet for ParallelAgent.")
yield # AsyncGenerator requires having at least one yield statement
+31
View File
@@ -23,6 +23,7 @@ from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from .base_agent import BaseAgent
from .llm_agent import LlmAgent
class SequentialAgent(BaseAgent):
@@ -40,6 +41,36 @@ class SequentialAgent(BaseAgent):
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Implementation for live SequentialAgent.
Compared to non-live case, live agents process a continous streams of audio
or video, so it doesn't have a way to tell if it's finished and should pass
to next agent or not. So we introduce a task_compelted() function so the
model can call this function to signal that it's finished the task and we
can move on to next agent.
Args:
ctx: The invocation context of the agent.
"""
# There is no way to know if it's using live during init phase so we have to init it here
for sub_agent in self.sub_agents:
# add tool
def task_completed():
"""
Signals that the model has successfully completed the user's question
or task.
"""
return "Task completion signaled."
if isinstance(sub_agent, LlmAgent):
# Use function name to dedupe.
if task_completed.__name__ not in sub_agent.tools:
sub_agent.tools.append(task_completed)
sub_agent.instruction += f"""If you finished the user' request
according to its description, call {task_completed.__name__} function
to exit so the next agents can take over. When calling this function,
do not generate any text other than the function call.'"""
for sub_agent in self.sub_agents:
async for event in sub_agent.run_live(ctx):
yield event
+6 -1
View File
@@ -15,13 +15,18 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import alias_generators
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
class BaseModelWithConfig(BaseModel):
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(
extra="allow",
alias_generator=alias_generators.to_camel,
populate_by_name=True,
)
"""The pydantic model config."""
+3 -4
View File
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from .auth_credential import AuthCredential
from .auth_credential import BaseModelWithConfig
from .auth_schemes import AuthScheme
class AuthConfig(BaseModel):
class AuthConfig(BaseModelWithConfig):
"""The auth config sent by tool asking client to collect auth credentials and
adk and client will help to fill in the response
@@ -45,7 +44,7 @@ class AuthConfig(BaseModel):
this field"""
class AuthToolArguments(BaseModel):
class AuthToolArguments(BaseModelWithConfig):
"""the arguments for the special long running function tool that is used to
request end user credentials.

Some files were not shown because too many files have changed in this diff Show More