You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add a handwritten tool for Cloud Pub/Sub
Merge https://github.com/google/adk-python/pull/3865 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3865 from kamalaboulhosn:main 37a38a4dcadb04a0af0ec584e6f611204a63cd2a PiperOrigin-RevId: 845345128
This commit is contained in:
committed by
Copybara-Service
parent
7e6ef71eec
commit
b6f6dcbeb4
@@ -0,0 +1,88 @@
|
||||
# Pub/Sub Tools Sample
|
||||
|
||||
## Introduction
|
||||
|
||||
This sample agent demonstrates the Pub/Sub first-party tools in ADK,
|
||||
distributed via the `google.adk.tools.pubsub` module. These tools include:
|
||||
|
||||
1. `publish_message`
|
||||
|
||||
Publishes a message to a Pub/Sub topic.
|
||||
|
||||
2. `pull_messages`
|
||||
|
||||
Pulls messages from a Pub/Sub subscription.
|
||||
|
||||
3. `acknowledge_messages`
|
||||
|
||||
Acknowledges messages on a Pub/Sub subscription.
|
||||
|
||||
## How to use
|
||||
|
||||
Set up environment variables in your `.env` file for using
|
||||
[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio)
|
||||
or
|
||||
[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai)
|
||||
for the LLM service for your agent. For example, for using Google AI Studio you
|
||||
would set:
|
||||
|
||||
* GOOGLE_GENAI_USE_VERTEXAI=FALSE
|
||||
* GOOGLE_API_KEY={your api key}
|
||||
|
||||
### With Application Default Credentials
|
||||
|
||||
This mode is useful for quick development when the agent builder is the only
|
||||
user interacting with the agent. The tools are run with these credentials.
|
||||
|
||||
1. Create application default credentials on the machine where the agent would
|
||||
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
|
||||
|
||||
1. Set `CREDENTIALS_TYPE=None` in `agent.py`
|
||||
|
||||
1. Run the agent
|
||||
|
||||
### With Service Account Keys
|
||||
|
||||
This mode is useful for quick development when the agent builder wants to run
|
||||
the agent with service account credentials. The tools are run with these
|
||||
credentials.
|
||||
|
||||
1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
|
||||
|
||||
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py`
|
||||
|
||||
1. Download the key file and replace `"service_account_key.json"` with the path
|
||||
|
||||
1. Run the agent
|
||||
|
||||
### With Interactive OAuth
|
||||
|
||||
1. Follow
|
||||
https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name.
|
||||
to get your client id and client secret. Be sure to choose "web" as your client
|
||||
type.
|
||||
|
||||
1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/pubsub".
|
||||
|
||||
1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs".
|
||||
|
||||
Note: localhost here is just a hostname that you use to access the dev ui,
|
||||
replace it with the actual hostname you use to access the dev ui.
|
||||
|
||||
1. For 1st run, allow popup for localhost in Chrome.
|
||||
|
||||
1. Configure your `.env` file to add two more variables before running the agent:
|
||||
|
||||
* OAUTH_CLIENT_ID={your client id}
|
||||
* OAUTH_CLIENT_SECRET={your client secret}
|
||||
|
||||
Note: don't create a separate .env, instead put it to the same .env file that
|
||||
stores your Vertex AI or Dev ML credentials
|
||||
|
||||
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent
|
||||
|
||||
## Sample prompts
|
||||
|
||||
* publish 'Hello World' to 'my-topic'
|
||||
* pull messages from 'my-subscription'
|
||||
* acknowledge message 'ack-id' from 'my-subscription'
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.pubsub.config import PubSubToolConfig
|
||||
from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig
|
||||
from google.adk.tools.pubsub.pubsub_toolset import PubSubToolset
|
||||
import google.auth
|
||||
|
||||
# Define the desired credential type.
|
||||
# By default use Application Default Credentials (ADC) from the local
|
||||
# environment, which can be set up by following
|
||||
# https://cloud.google.com/docs/authentication/provide-credentials-adc.
|
||||
CREDENTIALS_TYPE = None
|
||||
|
||||
# Define an appropriate application name
|
||||
PUBSUB_AGENT_NAME = "adk_sample_pubsub_agent"
|
||||
|
||||
|
||||
# Define Pub/Sub tool config.
|
||||
# You can optionally set the project_id here, or let the agent infer it from context/user input.
|
||||
tool_config = PubSubToolConfig(project_id=os.getenv("GOOGLE_CLOUD_PROJECT"))
|
||||
|
||||
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
|
||||
# Initialize the tools to do interactive OAuth
|
||||
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
|
||||
# must be set
|
||||
credentials_config = PubSubCredentialsConfig(
|
||||
client_id=os.getenv("OAUTH_CLIENT_ID"),
|
||||
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
|
||||
)
|
||||
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
|
||||
# Initialize the tools to use the credentials in the service account key.
|
||||
# If this flow is enabled, make sure to replace the file path with your own
|
||||
# service account key file
|
||||
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
|
||||
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
|
||||
credentials_config = PubSubCredentialsConfig(credentials=creds)
|
||||
else:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
# https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
credentials_config = PubSubCredentialsConfig(
|
||||
credentials=application_default_credentials
|
||||
)
|
||||
|
||||
pubsub_toolset = PubSubToolset(
|
||||
credentials_config=credentials_config, pubsub_tool_config=tool_config
|
||||
)
|
||||
|
||||
# The variable name `root_agent` determines what your root agent is for the
|
||||
# debug CLI
|
||||
root_agent = LlmAgent(
|
||||
model="gemini-2.5-flash",
|
||||
name=PUBSUB_AGENT_NAME,
|
||||
description=(
|
||||
"Agent to publish, pull, and acknowledge messages from Google Cloud"
|
||||
" Pub/Sub."
|
||||
),
|
||||
instruction=textwrap.dedent("""\
|
||||
You are a cloud engineer agent with access to Google Cloud Pub/Sub tools.
|
||||
You can publish messages to topics, pull messages from subscriptions, and acknowledge messages.
|
||||
"""),
|
||||
tools=[pubsub_toolset],
|
||||
)
|
||||
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"google-cloud-bigquery>=2.2.0",
|
||||
"google-cloud-bigtable>=2.32.0", # For Bigtable database
|
||||
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
|
||||
"google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool
|
||||
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
|
||||
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
|
||||
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
|
||||
|
||||
@@ -32,6 +32,7 @@ class FeatureName(str, Enum):
|
||||
GOOGLE_TOOL = "GOOGLE_TOOL"
|
||||
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
|
||||
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
|
||||
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
|
||||
SPANNER_TOOLSET = "SPANNER_TOOLSET"
|
||||
SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS"
|
||||
|
||||
@@ -90,6 +91,9 @@ _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
|
||||
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
|
||||
FeatureStage.WIP, default_on=False
|
||||
),
|
||||
FeatureName.PUBSUB_TOOLSET: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
FeatureName.SPANNER_TOOLSET: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pub/Sub Tools (Experimental).
|
||||
|
||||
Pub/Sub Tools under this module are hand crafted and customized while the tools
|
||||
under google.adk.tools.google_api_tool are auto generated based on API
|
||||
definition. The rationales to have customized tool are:
|
||||
|
||||
1. Better handling of base64 encoding for published messages.
|
||||
2. A richer subscribe-side API that reflects how users may want to pull/ack
|
||||
messages.
|
||||
"""
|
||||
|
||||
from .config import PubSubToolConfig
|
||||
from .pubsub_credentials import PubSubCredentialsConfig
|
||||
from .pubsub_toolset import PubSubToolset
|
||||
|
||||
__all__ = ["PubSubCredentialsConfig", "PubSubToolConfig", "PubSubToolset"]
|
||||
@@ -0,0 +1,165 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from google.api_core.gapic_v1.client_info import ClientInfo
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import pubsub_v1
|
||||
from google.cloud.pubsub_v1.types import BatchSettings
|
||||
|
||||
from ... import version
|
||||
|
||||
USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}"
|
||||
|
||||
_CACHE_TTL = 1800 # 30 minutes
|
||||
|
||||
_publisher_client_cache = {}
|
||||
_publisher_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_publisher_client(
|
||||
*,
|
||||
credentials: Credentials,
|
||||
user_agent: str | list[str] | None = None,
|
||||
publisher_options: pubsub_v1.types.PublisherOptions | None = None,
|
||||
) -> pubsub_v1.PublisherClient:
|
||||
"""Get a Pub/Sub Publisher client.
|
||||
|
||||
Args:
|
||||
credentials: The credentials to use for the request.
|
||||
user_agent: The user agent to use for the request.
|
||||
publisher_options: The publisher options to use for the request.
|
||||
|
||||
Returns:
|
||||
A Pub/Sub Publisher client.
|
||||
"""
|
||||
global _publisher_client_cache
|
||||
current_time = time.time()
|
||||
|
||||
user_agents_key = None
|
||||
if user_agent:
|
||||
if isinstance(user_agent, str):
|
||||
user_agents_key = (user_agent,)
|
||||
else:
|
||||
user_agents_key = tuple(user_agent)
|
||||
|
||||
# Use object identity for credentials and publisher_options as they are not hashable
|
||||
key = (id(credentials), user_agents_key, id(publisher_options))
|
||||
|
||||
with _publisher_client_lock:
|
||||
if key in _publisher_client_cache:
|
||||
client, expiration = _publisher_client_cache[key]
|
||||
if expiration > current_time:
|
||||
return client
|
||||
|
||||
user_agents = [USER_AGENT]
|
||||
if user_agent:
|
||||
if isinstance(user_agent, str):
|
||||
user_agents.append(user_agent)
|
||||
else:
|
||||
user_agents.extend(ua for ua in user_agent if ua)
|
||||
|
||||
client_info = ClientInfo(user_agent=" ".join(user_agents))
|
||||
|
||||
# Since we synchronously publish messages, we want to disable batching to
|
||||
# remove any delay.
|
||||
custom_batch_settings = BatchSettings(max_messages=1)
|
||||
publisher_client = pubsub_v1.PublisherClient(
|
||||
credentials=credentials,
|
||||
client_info=client_info,
|
||||
publisher_options=publisher_options,
|
||||
batch_settings=custom_batch_settings,
|
||||
)
|
||||
|
||||
_publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL)
|
||||
|
||||
return publisher_client
|
||||
|
||||
|
||||
_subscriber_client_cache = {}
|
||||
_subscriber_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_subscriber_client(
|
||||
*,
|
||||
credentials: Credentials,
|
||||
user_agent: str | list[str] | None = None,
|
||||
) -> pubsub_v1.SubscriberClient:
|
||||
"""Get a Pub/Sub Subscriber client.
|
||||
|
||||
Args:
|
||||
credentials: The credentials to use for the request.
|
||||
user_agent: The user agent to use for the request.
|
||||
|
||||
Returns:
|
||||
A Pub/Sub Subscriber client.
|
||||
"""
|
||||
global _subscriber_client_cache
|
||||
current_time = time.time()
|
||||
|
||||
user_agents_key = None
|
||||
if user_agent:
|
||||
if isinstance(user_agent, str):
|
||||
user_agents_key = (user_agent,)
|
||||
else:
|
||||
user_agents_key = tuple(user_agent)
|
||||
|
||||
# Use object identity for credentials as they are not hashable
|
||||
key = (id(credentials), user_agents_key)
|
||||
|
||||
with _subscriber_client_lock:
|
||||
if key in _subscriber_client_cache:
|
||||
client, expiration = _subscriber_client_cache[key]
|
||||
if expiration > current_time:
|
||||
return client
|
||||
|
||||
user_agents = [USER_AGENT]
|
||||
if user_agent:
|
||||
if isinstance(user_agent, str):
|
||||
user_agents.append(user_agent)
|
||||
else:
|
||||
user_agents.extend(ua for ua in user_agent if ua)
|
||||
|
||||
client_info = ClientInfo(user_agent=" ".join(user_agents))
|
||||
|
||||
subscriber_client = pubsub_v1.SubscriberClient(
|
||||
credentials=credentials,
|
||||
client_info=client_info,
|
||||
)
|
||||
|
||||
_subscriber_client_cache[key] = (
|
||||
subscriber_client,
|
||||
current_time + _CACHE_TTL,
|
||||
)
|
||||
|
||||
return subscriber_client
|
||||
|
||||
|
||||
def cleanup_clients():
|
||||
"""Clean up all cached Pub/Sub clients."""
|
||||
global _publisher_client_cache, _subscriber_client_cache
|
||||
|
||||
with _publisher_client_lock:
|
||||
for client, _ in _publisher_client_cache.values():
|
||||
client.transport.close()
|
||||
_publisher_client_cache.clear()
|
||||
|
||||
with _subscriber_client_lock:
|
||||
for client, _ in _subscriber_client_cache.values():
|
||||
client.close()
|
||||
_subscriber_client_cache.clear()
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from ...utils.feature_decorator import experimental
|
||||
|
||||
|
||||
@experimental('Config defaults may have breaking change in the future.')
|
||||
class PubSubToolConfig(BaseModel):
|
||||
"""Configuration for Pub/Sub tools."""
|
||||
|
||||
# Forbid any fields not defined in the model
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
project_id: str | None = None
|
||||
"""GCP project ID to use for the Pub/Sub operations.
|
||||
|
||||
If not set, the project ID will be inferred from the environment or
|
||||
credentials.
|
||||
"""
|
||||
@@ -0,0 +1,187 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import pubsub_v1
|
||||
|
||||
from . import client
|
||||
from .config import PubSubToolConfig
|
||||
|
||||
|
||||
def publish_message(
|
||||
topic_name: str,
|
||||
message: str,
|
||||
credentials: Credentials,
|
||||
settings: PubSubToolConfig,
|
||||
attributes: Optional[dict[str, str]] = None,
|
||||
ordering_key: str = "",
|
||||
) -> dict:
|
||||
"""Publish a message to a Pub/Sub topic.
|
||||
|
||||
Args:
|
||||
topic_name (str): The Pub/Sub topic name (e.g.
|
||||
projects/my-project/topics/my-topic).
|
||||
message (str): The message content to publish.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (PubSubToolConfig): The Pub/Sub tool settings.
|
||||
attributes (Optional[dict[str, str]]): Attributes to attach to the message.
|
||||
ordering_key (str): Ordering key for the message.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with the message_id of the published message.
|
||||
"""
|
||||
try:
|
||||
publisher_options = pubsub_v1.types.PublisherOptions(
|
||||
enable_message_ordering=bool(ordering_key)
|
||||
)
|
||||
publisher_client = client.get_publisher_client(
|
||||
credentials=credentials,
|
||||
user_agent=[settings.project_id, "publish_message"],
|
||||
publisher_options=publisher_options,
|
||||
)
|
||||
|
||||
message_bytes = message.encode("utf-8")
|
||||
future = publisher_client.publish(
|
||||
topic_name,
|
||||
data=message_bytes,
|
||||
ordering_key=ordering_key,
|
||||
**(attributes or {}),
|
||||
)
|
||||
|
||||
return {"message_id": future.result()}
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": (
|
||||
f"Failed to publish message to topic '{topic_name}': {repr(ex)}"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _decode_message_data(data: bytes) -> str:
|
||||
"""Decodes message data, trying UTF-8 and falling back to base64."""
|
||||
try:
|
||||
return data.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
# If UTF-8 decoding fails, encode as base64 string
|
||||
return base64.b64encode(data).decode("ascii")
|
||||
|
||||
|
||||
def pull_messages(
|
||||
subscription_name: str,
|
||||
credentials: Credentials,
|
||||
settings: PubSubToolConfig,
|
||||
*,
|
||||
max_messages: int = 1,
|
||||
auto_ack: bool = False,
|
||||
) -> dict:
|
||||
"""Pull messages from a Pub/Sub subscription.
|
||||
|
||||
Args:
|
||||
subscription_name (str): The Pub/Sub subscription name (e.g.
|
||||
projects/my-project/subscriptions/my-sub).
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (PubSubToolConfig): The Pub/Sub tool settings.
|
||||
max_messages (int): The maximum number of messages to pull. Defaults to 1.
|
||||
auto_ack (bool): Whether to automatically acknowledge the messages.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with the list of pulled messages.
|
||||
"""
|
||||
try:
|
||||
subscriber_client = client.get_subscriber_client(
|
||||
credentials=credentials,
|
||||
user_agent=[settings.project_id, "pull_messages"],
|
||||
)
|
||||
|
||||
response = subscriber_client.pull(
|
||||
subscription=subscription_name,
|
||||
max_messages=max_messages,
|
||||
)
|
||||
|
||||
messages = []
|
||||
ack_ids = []
|
||||
for received_message in response.received_messages:
|
||||
message_data = _decode_message_data(received_message.message.data)
|
||||
messages.append({
|
||||
"message_id": received_message.message.message_id,
|
||||
"data": message_data,
|
||||
"attributes": dict(received_message.message.attributes),
|
||||
"ordering_key": received_message.message.ordering_key,
|
||||
"publish_time": received_message.message.publish_time.rfc3339(),
|
||||
"ack_id": received_message.ack_id,
|
||||
})
|
||||
ack_ids.append(received_message.ack_id)
|
||||
|
||||
if auto_ack and ack_ids:
|
||||
subscriber_client.acknowledge(
|
||||
subscription=subscription_name,
|
||||
ack_ids=ack_ids,
|
||||
)
|
||||
|
||||
return {"messages": messages}
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": (
|
||||
f"Failed to pull messages from subscription '{subscription_name}':"
|
||||
f" {repr(ex)}"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def acknowledge_messages(
|
||||
subscription_name: str,
|
||||
ack_ids: list[str],
|
||||
credentials: Credentials,
|
||||
settings: PubSubToolConfig,
|
||||
) -> dict:
|
||||
"""Acknowledge messages on a Pub/Sub subscription.
|
||||
|
||||
Args:
|
||||
subscription_name (str): The Pub/Sub subscription name (e.g.
|
||||
projects/my-project/subscriptions/my-sub).
|
||||
ack_ids (list[str]): List of acknowledgment IDs to acknowledge.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (PubSubToolConfig): The Pub/Sub tool settings.
|
||||
|
||||
Returns:
|
||||
dict: Status of the operation.
|
||||
"""
|
||||
try:
|
||||
subscriber_client = client.get_subscriber_client(
|
||||
credentials=credentials,
|
||||
user_agent=[settings.project_id, "acknowledge_messages"],
|
||||
)
|
||||
|
||||
subscriber_client.acknowledge(
|
||||
subscription=subscription_name,
|
||||
ack_ids=ack_ids,
|
||||
)
|
||||
|
||||
return {"status": "SUCCESS"}
|
||||
except Exception as ex:
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": (
|
||||
"Failed to acknowledge messages on subscription"
|
||||
f" '{subscription_name}': {repr(ex)}"
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from ...features import experimental
|
||||
from ...features import FeatureName
|
||||
from .._google_credentials import BaseGoogleCredentialsConfig
|
||||
|
||||
PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache"
|
||||
PUBSUB_DEFAULT_SCOPE = ("https://www.googleapis.com/auth/pubsub",)
|
||||
|
||||
|
||||
@experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG)
|
||||
class PubSubCredentialsConfig(BaseGoogleCredentialsConfig):
|
||||
"""Pub/Sub Credentials Configuration for Google API tools (Experimental).
|
||||
|
||||
Please do not use this in production, as it may be deprecated later.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def __post_init__(self) -> PubSubCredentialsConfig:
|
||||
"""Populate default scope if scopes is None."""
|
||||
super().__post_init__()
|
||||
|
||||
if not self.scopes:
|
||||
self.scopes = PUBSUB_DEFAULT_SCOPE
|
||||
|
||||
# Set the token cache key
|
||||
self._token_cache_key = PUBSUB_TOKEN_CACHE_KEY
|
||||
|
||||
return self
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from typing_extensions import override
|
||||
|
||||
from . import client
|
||||
from . import message_tool
|
||||
from ...features import experimental
|
||||
from ...features import FeatureName
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.base_toolset import BaseToolset
|
||||
from ...tools.base_toolset import ToolPredicate
|
||||
from ...tools.google_tool import GoogleTool
|
||||
from .config import PubSubToolConfig
|
||||
from .pubsub_credentials import PubSubCredentialsConfig
|
||||
|
||||
|
||||
@experimental(FeatureName.PUBSUB_TOOLSET)
|
||||
class PubSubToolset(BaseToolset):
|
||||
"""Pub/Sub Toolset contains tools for interacting with Pub/Sub topics and subscriptions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_filter: ToolPredicate | list[str] | None = None,
|
||||
credentials_config: PubSubCredentialsConfig | None = None,
|
||||
pubsub_tool_config: PubSubToolConfig | None = None,
|
||||
):
|
||||
"""Initializes the PubSubToolset.
|
||||
|
||||
Args:
|
||||
tool_filter: A predicate or list of tool names to filter the tools in
|
||||
the toolset. If None, all tools are included.
|
||||
credentials_config: The credentials configuration to use for
|
||||
authenticating with Google Cloud.
|
||||
pubsub_tool_config: The configuration for the Pub/Sub tools.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
self._credentials_config = credentials_config
|
||||
self._tool_settings = (
|
||||
pubsub_tool_config if pubsub_tool_config else PubSubToolConfig()
|
||||
)
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||
) -> bool:
|
||||
if self.tool_filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
|
||||
return False
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: ReadonlyContext | None = None
|
||||
) -> list[BaseTool]:
|
||||
"""Get tools from the toolset."""
|
||||
all_tools = [
|
||||
GoogleTool(
|
||||
func=func,
|
||||
credentials_config=self._credentials_config,
|
||||
tool_settings=self._tool_settings,
|
||||
)
|
||||
for func in [
|
||||
message_tool.publish_message,
|
||||
message_tool.pull_messages,
|
||||
message_tool.acknowledge_messages,
|
||||
]
|
||||
]
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if self._is_tool_selected(tool, readonly_context)
|
||||
]
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
"""Clean up resources used by the toolset."""
|
||||
client.cleanup_clients()
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.pubsub import client
|
||||
from google.cloud import pubsub_v1
|
||||
from google.oauth2.credentials import Credentials
|
||||
import pytest
|
||||
|
||||
# Save original Pub/Sub classes before patching.
|
||||
# This is necessary because create_autospec cannot be used on a mock object,
|
||||
# and mock.patch.object(..., autospec=True) replaces the class with a mock.
|
||||
# We need the original class to create spec'd mocks in side_effect.
|
||||
ORIG_PUBLISHER = pubsub_v1.PublisherClient
|
||||
ORIG_SUBSCRIBER = pubsub_v1.SubscriberClient
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_pubsub_clients():
|
||||
"""Automatically clean up Pub/Sub client caches after each test.
|
||||
|
||||
This fixture runs automatically for all tests in this file,
|
||||
ensuring that client caches are cleared between tests to prevent
|
||||
state leakage and ensure test isolation.
|
||||
"""
|
||||
yield
|
||||
client.cleanup_clients()
|
||||
|
||||
|
||||
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
|
||||
def test_get_publisher_client(mock_publisher_client):
|
||||
"""Test get_publisher_client factory."""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
client.get_publisher_client(credentials=mock_creds)
|
||||
|
||||
mock_publisher_client.assert_called_once()
|
||||
_, kwargs = mock_publisher_client.call_args
|
||||
assert kwargs["credentials"] == mock_creds
|
||||
assert "client_info" in kwargs
|
||||
assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings)
|
||||
assert kwargs["batch_settings"].max_messages == 1
|
||||
|
||||
|
||||
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
|
||||
def test_get_publisher_client_with_options(mock_publisher_client):
|
||||
"""Test get_publisher_client factory with options."""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
mock_options = mock.create_autospec(
|
||||
pubsub_v1.types.PublisherOptions, instance=True, spec_set=True
|
||||
)
|
||||
client.get_publisher_client(
|
||||
credentials=mock_creds, publisher_options=mock_options
|
||||
)
|
||||
|
||||
mock_publisher_client.assert_called_once()
|
||||
_, kwargs = mock_publisher_client.call_args
|
||||
assert kwargs["credentials"] == mock_creds
|
||||
assert kwargs["publisher_options"] == mock_options
|
||||
assert "client_info" in kwargs
|
||||
assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings)
|
||||
assert kwargs["batch_settings"].max_messages == 1
|
||||
|
||||
|
||||
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
|
||||
def test_get_publisher_client_caching(mock_publisher_client):
|
||||
"""Test get_publisher_client caching behavior."""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
mock_publisher_client.side_effect = [
|
||||
mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True),
|
||||
mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True),
|
||||
]
|
||||
|
||||
# First call - should create client
|
||||
client1 = client.get_publisher_client(credentials=mock_creds)
|
||||
mock_publisher_client.assert_called_once()
|
||||
|
||||
# Second call with same args - should return cached client
|
||||
client2 = client.get_publisher_client(credentials=mock_creds)
|
||||
assert client1 is client2
|
||||
mock_publisher_client.assert_called_once() # Still called only once
|
||||
|
||||
# Call with different args - should create new client
|
||||
mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
client3 = client.get_publisher_client(credentials=mock_creds2)
|
||||
assert client3 is not client1
|
||||
assert mock_publisher_client.call_count == 2
|
||||
|
||||
|
||||
@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True)
|
||||
def test_get_subscriber_client(mock_subscriber_client):
|
||||
"""Test get_subscriber_client factory."""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
client.get_subscriber_client(credentials=mock_creds)
|
||||
|
||||
mock_subscriber_client.assert_called_once()
|
||||
_, kwargs = mock_subscriber_client.call_args
|
||||
assert kwargs["credentials"] == mock_creds
|
||||
assert "client_info" in kwargs
|
||||
|
||||
|
||||
@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True)
|
||||
def test_get_subscriber_client_caching(mock_subscriber_client):
|
||||
"""Test get_subscriber_client caching behavior."""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
mock_subscriber_client.side_effect = [
|
||||
mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True),
|
||||
mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True),
|
||||
]
|
||||
|
||||
# First call - should create client
|
||||
client1 = client.get_subscriber_client(credentials=mock_creds)
|
||||
mock_subscriber_client.assert_called_once()
|
||||
|
||||
# Second call with same args - should return cached client
|
||||
client2 = client.get_subscriber_client(credentials=mock_creds)
|
||||
assert client1 is client2
|
||||
mock_subscriber_client.assert_called_once() # Still called only once
|
||||
|
||||
# Call with different args - should create new client
|
||||
mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True)
|
||||
client3 = client.get_subscriber_client(credentials=mock_creds2)
|
||||
assert client3 is not client1
|
||||
assert mock_subscriber_client.call_count == 2
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from google.adk.tools.pubsub.config import PubSubToolConfig
|
||||
|
||||
|
||||
def test_pubsub_tool_config_init():
|
||||
"""Test PubSubToolConfig initialization."""
|
||||
config = PubSubToolConfig(project_id="my-project")
|
||||
assert config.project_id == "my-project"
|
||||
|
||||
|
||||
def test_pubsub_tool_config_default():
|
||||
"""Test PubSubToolConfig default initialization."""
|
||||
config = PubSubToolConfig()
|
||||
assert config.project_id is None
|
||||
@@ -0,0 +1,132 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.pubsub.pubsub_credentials import PUBSUB_DEFAULT_SCOPE
|
||||
from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig
|
||||
from google.auth.credentials import Credentials
|
||||
import google.oauth2.credentials
|
||||
import pytest
|
||||
|
||||
"""Test suite for PubSub credentials configuration validation.
|
||||
|
||||
This class tests the credential configuration logic that ensures
|
||||
either existing credentials or client ID/secret pairs are provided.
|
||||
"""
|
||||
|
||||
|
||||
def test_pubsub_credentials_config_client_id_secret():
|
||||
"""Test PubSubCredentialsConfig with client_id and client_secret.
|
||||
|
||||
Ensures that when client_id and client_secret are provided, the config
|
||||
object is created with the correct attributes.
|
||||
"""
|
||||
config = PubSubCredentialsConfig(client_id="abc", client_secret="def")
|
||||
assert config.client_id == "abc"
|
||||
assert config.client_secret == "def"
|
||||
assert config.scopes == PUBSUB_DEFAULT_SCOPE
|
||||
assert config.credentials is None
|
||||
|
||||
|
||||
def test_pubsub_credentials_config_existing_creds():
|
||||
"""Test PubSubCredentialsConfig with existing generic credentials.
|
||||
|
||||
Ensures that when a generic Credentials object is provided, it is
|
||||
stored correctly.
|
||||
"""
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True)
|
||||
config = PubSubCredentialsConfig(credentials=mock_creds)
|
||||
assert config.credentials == mock_creds
|
||||
assert config.client_id is None
|
||||
assert config.client_secret is None
|
||||
|
||||
|
||||
def test_pubsub_credentials_config_oauth2_creds():
|
||||
"""Test PubSubCredentialsConfig with existing OAuth2 credentials.
|
||||
|
||||
Ensures that when a google.oauth2.credentials.Credentials object is
|
||||
provided, the client_id, client_secret, and scopes are extracted
|
||||
from the credentials object.
|
||||
"""
|
||||
mock_creds = mock.create_autospec(
|
||||
google.oauth2.credentials.Credentials, instance=True
|
||||
)
|
||||
mock_creds.client_id = "oauth_client_id"
|
||||
mock_creds.client_secret = "oauth_client_secret"
|
||||
mock_creds.scopes = ["fake_scope"]
|
||||
config = PubSubCredentialsConfig(credentials=mock_creds)
|
||||
assert config.client_id == "oauth_client_id"
|
||||
assert config.client_secret == "oauth_client_secret"
|
||||
assert config.scopes == ["fake_scope"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"credentials, client_id, client_secret",
|
||||
[
|
||||
# No arguments provided
|
||||
(None, None, None),
|
||||
# Only client_id is provided
|
||||
(None, "abc", None),
|
||||
],
|
||||
)
|
||||
def test_pubsub_credentials_config_validation_errors(
|
||||
credentials, client_id, client_secret
|
||||
):
|
||||
"""Test PubSubCredentialsConfig validation errors.
|
||||
|
||||
Ensures that ValueError is raised when invalid combinations of credentials
|
||||
and client ID/secret are provided.
|
||||
|
||||
Args:
|
||||
credentials: The credentials object to pass.
|
||||
client_id: The client ID to pass.
|
||||
client_secret: The client secret to pass.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Must provide either credentials or client_id and client_secret pair."
|
||||
),
|
||||
):
|
||||
PubSubCredentialsConfig(
|
||||
credentials=credentials,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
|
||||
def test_pubsub_credentials_config_both_credentials_and_client_provided():
|
||||
"""Test PubSubCredentialsConfig validation errors.
|
||||
|
||||
Ensures that ValueError is raised when invalid combinations of credentials
|
||||
and client ID/secret are provided.
|
||||
|
||||
Args:
|
||||
credentials: The credentials object to pass.
|
||||
client_id: The client ID to pass.
|
||||
client_secret: The client secret to pass.
|
||||
"""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"Cannot provide both existing credentials and"
|
||||
" client_id/client_secret/scopes."
|
||||
),
|
||||
):
|
||||
PubSubCredentialsConfig(
|
||||
credentials=mock.create_autospec(Credentials, instance=True),
|
||||
client_id="abc",
|
||||
client_secret="def",
|
||||
)
|
||||
@@ -0,0 +1,330 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.pubsub import client as pubsub_client_lib
|
||||
from google.adk.tools.pubsub import message_tool
|
||||
from google.adk.tools.pubsub.config import PubSubToolConfig
|
||||
from google.api_core import future
|
||||
from google.cloud import pubsub_v1
|
||||
from google.cloud.pubsub_v1 import types
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google.protobuf import timestamp_pb2
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
|
||||
def test_publish_message(mock_get_publisher_client, mock_publish):
|
||||
"""Test publish_message tool invocation."""
|
||||
topic_name = "projects/my_project_id/topics/my_topic"
|
||||
message = "Hello World"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_publisher_client = mock.create_autospec(
|
||||
pubsub_v1.PublisherClient, instance=True
|
||||
)
|
||||
mock_get_publisher_client.return_value = mock_publisher_client
|
||||
|
||||
mock_future = mock.create_autospec(future.Future, instance=True)
|
||||
mock_future.result.return_value = "message_id"
|
||||
mock_publisher_client.publish.return_value = mock_future
|
||||
|
||||
result = message_tool.publish_message(
|
||||
topic_name, message, mock_credentials, tool_settings
|
||||
)
|
||||
|
||||
assert result["message_id"] == "message_id"
|
||||
mock_get_publisher_client.assert_called_once()
|
||||
mock_publisher_client.publish.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
|
||||
def test_publish_message_with_ordering_key(
|
||||
mock_get_publisher_client, mock_publish
|
||||
):
|
||||
"""Test publish_message tool invocation with ordering_key."""
|
||||
topic_name = "projects/my_project_id/topics/my_topic"
|
||||
message = "Hello World"
|
||||
ordering_key = "key1"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_publisher_client = mock.create_autospec(
|
||||
pubsub_v1.PublisherClient, instance=True
|
||||
)
|
||||
mock_get_publisher_client.return_value = mock_publisher_client
|
||||
|
||||
mock_future = mock.create_autospec(future.Future, instance=True)
|
||||
mock_future.result.return_value = "message_id"
|
||||
mock_publisher_client.publish.return_value = mock_future
|
||||
|
||||
result = message_tool.publish_message(
|
||||
topic_name,
|
||||
message,
|
||||
mock_credentials,
|
||||
tool_settings,
|
||||
ordering_key=ordering_key,
|
||||
)
|
||||
|
||||
assert result["message_id"] == "message_id"
|
||||
mock_get_publisher_client.assert_called_once()
|
||||
_, kwargs = mock_get_publisher_client.call_args
|
||||
assert kwargs["publisher_options"].enable_message_ordering is True
|
||||
|
||||
mock_publisher_client.publish.assert_called_once()
|
||||
|
||||
# Verify ordering_key was passed
|
||||
_, kwargs = mock_publisher_client.publish.call_args
|
||||
assert kwargs["ordering_key"] == ordering_key
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
|
||||
def test_publish_message_with_attributes(
|
||||
mock_get_publisher_client, mock_publish
|
||||
):
|
||||
"""Test publish_message tool invocation with attributes."""
|
||||
topic_name = "projects/my_project_id/topics/my_topic"
|
||||
message = "Hello World"
|
||||
attributes = {"key1": "value1", "key2": "value2"}
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_publisher_client = mock.create_autospec(
|
||||
pubsub_v1.PublisherClient, instance=True
|
||||
)
|
||||
mock_get_publisher_client.return_value = mock_publisher_client
|
||||
|
||||
mock_future = mock.create_autospec(future.Future, instance=True)
|
||||
mock_future.result.return_value = "message_id"
|
||||
mock_publisher_client.publish.return_value = mock_future
|
||||
|
||||
result = message_tool.publish_message(
|
||||
topic_name,
|
||||
message,
|
||||
mock_credentials,
|
||||
tool_settings,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
assert result["message_id"] == "message_id"
|
||||
mock_get_publisher_client.assert_called_once()
|
||||
mock_publisher_client.publish.assert_called_once()
|
||||
|
||||
# Verify attributes were passed
|
||||
_, kwargs = mock_publisher_client.publish.call_args
|
||||
assert kwargs["key1"] == "value1"
|
||||
assert kwargs["key2"] == "value2"
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
|
||||
def test_publish_message_exception(mock_get_publisher_client, mock_publish):
|
||||
"""Test publish_message tool invocation when exception occurs."""
|
||||
topic_name = "projects/my_project_id/topics/my_topic"
|
||||
message = "Hello World"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_publisher_client = mock.create_autospec(
|
||||
pubsub_v1.PublisherClient, instance=True
|
||||
)
|
||||
mock_get_publisher_client.return_value = mock_publisher_client
|
||||
|
||||
# Simulate an exception during publish
|
||||
mock_publisher_client.publish.side_effect = Exception("Publish failed")
|
||||
|
||||
result = message_tool.publish_message(
|
||||
topic_name,
|
||||
message,
|
||||
mock_credentials,
|
||||
tool_settings,
|
||||
)
|
||||
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Publish failed" in result["error_details"]
|
||||
mock_get_publisher_client.assert_called_once()
|
||||
mock_publisher_client.publish.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
|
||||
def test_pull_messages(mock_get_subscriber_client):
|
||||
"""Test pull_messages tool invocation."""
|
||||
subscription_name = "projects/my_project_id/subscriptions/my_sub"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_subscriber_client = mock.create_autospec(
|
||||
pubsub_v1.SubscriberClient, instance=True
|
||||
)
|
||||
mock_get_subscriber_client.return_value = mock_subscriber_client
|
||||
|
||||
mock_response = mock.create_autospec(types.PullResponse, instance=True)
|
||||
mock_message = mock.MagicMock()
|
||||
mock_message.message.message_id = "123"
|
||||
mock_message.message.data = b"Hello"
|
||||
mock_message.message.attributes = {"key": "value"}
|
||||
mock_message.message.ordering_key = "ABC"
|
||||
mock_publish_time = mock.MagicMock()
|
||||
mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z"
|
||||
mock_message.message.publish_time = mock_publish_time
|
||||
mock_message.ack_id = "ack_123"
|
||||
mock_response.received_messages = [mock_message]
|
||||
mock_subscriber_client.pull.return_value = mock_response
|
||||
|
||||
result = message_tool.pull_messages(
|
||||
subscription_name, mock_credentials, tool_settings
|
||||
)
|
||||
|
||||
expected_message = {
|
||||
"message_id": "123",
|
||||
"data": "Hello",
|
||||
"attributes": {"key": "value"},
|
||||
"ordering_key": "ABC",
|
||||
"publish_time": "2023-01-01T00:00:00Z",
|
||||
"ack_id": "ack_123",
|
||||
}
|
||||
assert result["messages"] == [expected_message]
|
||||
|
||||
mock_get_subscriber_client.assert_called_once()
|
||||
mock_subscriber_client.pull.assert_called_once_with(
|
||||
subscription=subscription_name, max_messages=1
|
||||
)
|
||||
mock_subscriber_client.acknowledge.assert_not_called()
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
|
||||
def test_pull_messages_auto_ack(mock_get_subscriber_client):
|
||||
"""Test pull_messages tool invocation with auto_ack."""
|
||||
subscription_name = "projects/my_project_id/subscriptions/my_sub"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_subscriber_client = mock.create_autospec(
|
||||
pubsub_v1.SubscriberClient, instance=True
|
||||
)
|
||||
mock_get_subscriber_client.return_value = mock_subscriber_client
|
||||
|
||||
mock_response = mock.create_autospec(types.PullResponse, instance=True)
|
||||
mock_message = mock.MagicMock()
|
||||
mock_message.message.message_id = "123"
|
||||
mock_message.message.data = b"Hello"
|
||||
mock_message.message.attributes = {}
|
||||
mock_publish_time = mock.MagicMock()
|
||||
mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z"
|
||||
mock_message.message.publish_time = mock_publish_time
|
||||
mock_message.ack_id = "ack_123"
|
||||
mock_response.received_messages = [mock_message]
|
||||
mock_subscriber_client.pull.return_value = mock_response
|
||||
|
||||
result = message_tool.pull_messages(
|
||||
subscription_name,
|
||||
mock_credentials,
|
||||
tool_settings,
|
||||
max_messages=5,
|
||||
auto_ack=True,
|
||||
)
|
||||
|
||||
assert len(result["messages"]) == 1
|
||||
mock_get_subscriber_client.assert_called_once()
|
||||
mock_subscriber_client.pull.assert_called_once_with(
|
||||
subscription=subscription_name, max_messages=5
|
||||
)
|
||||
mock_subscriber_client.acknowledge.assert_called_once_with(
|
||||
subscription=subscription_name, ack_ids=["ack_123"]
|
||||
)
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
|
||||
def test_pull_messages_exception(mock_get_subscriber_client):
|
||||
"""Test pull_messages tool invocation when exception occurs."""
|
||||
subscription_name = "projects/my_project_id/subscriptions/my_sub"
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_subscriber_client = mock.create_autospec(
|
||||
pubsub_v1.SubscriberClient, instance=True
|
||||
)
|
||||
mock_get_subscriber_client.return_value = mock_subscriber_client
|
||||
|
||||
mock_subscriber_client.pull.side_effect = Exception("Pull failed")
|
||||
|
||||
result = message_tool.pull_messages(
|
||||
subscription_name, mock_credentials, tool_settings
|
||||
)
|
||||
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Pull failed" in result["error_details"]
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
|
||||
def test_acknowledge_messages(mock_get_subscriber_client):
|
||||
"""Test acknowledge_messages tool invocation."""
|
||||
subscription_name = "projects/my_project_id/subscriptions/my_sub"
|
||||
ack_ids = ["ack1", "ack2"]
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_subscriber_client = mock.create_autospec(
|
||||
pubsub_v1.SubscriberClient, instance=True
|
||||
)
|
||||
mock_get_subscriber_client.return_value = mock_subscriber_client
|
||||
|
||||
result = message_tool.acknowledge_messages(
|
||||
subscription_name, ack_ids, mock_credentials, tool_settings
|
||||
)
|
||||
|
||||
assert result["status"] == "SUCCESS"
|
||||
mock_get_subscriber_client.assert_called_once()
|
||||
mock_subscriber_client.acknowledge.assert_called_once_with(
|
||||
subscription=subscription_name, ack_ids=ack_ids
|
||||
)
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
|
||||
def test_acknowledge_messages_exception(mock_get_subscriber_client):
|
||||
"""Test acknowledge_messages tool invocation when exception occurs."""
|
||||
subscription_name = "projects/my_project_id/subscriptions/my_sub"
|
||||
ack_ids = ["ack1"]
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_settings = PubSubToolConfig(project_id="my_project_id")
|
||||
|
||||
mock_subscriber_client = mock.create_autospec(
|
||||
pubsub_v1.SubscriberClient, instance=True
|
||||
)
|
||||
mock_get_subscriber_client.return_value = mock_subscriber_client
|
||||
|
||||
mock_subscriber_client.acknowledge.side_effect = Exception("Ack failed")
|
||||
|
||||
result = message_tool.acknowledge_messages(
|
||||
subscription_name, ack_ids, mock_credentials, tool_settings
|
||||
)
|
||||
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Ack failed" in result["error_details"]
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.tools.google_tool import GoogleTool
|
||||
from google.adk.tools.pubsub import PubSubCredentialsConfig
|
||||
from google.adk.tools.pubsub import PubSubToolset
|
||||
from google.adk.tools.pubsub.config import PubSubToolConfig
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub_toolset_tools_default():
|
||||
"""Test default PubSub toolset.
|
||||
|
||||
This test verifies the behavior of the PubSub toolset when no filter is
|
||||
specified.
|
||||
"""
|
||||
credentials_config = PubSubCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = PubSubToolset(
|
||||
credentials_config=credentials_config, pubsub_tool_config=None
|
||||
)
|
||||
# Verify that the tool config is initialized to default values.
|
||||
assert isinstance(toolset._tool_settings, PubSubToolConfig) # pylint: disable=protected-access
|
||||
assert toolset._tool_settings.__dict__ == PubSubToolConfig().__dict__ # pylint: disable=protected-access
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == 3
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set([
|
||||
"publish_message",
|
||||
"pull_messages",
|
||||
"acknowledge_messages",
|
||||
])
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"selected_tools",
|
||||
[
|
||||
pytest.param([], id="None"),
|
||||
pytest.param(["publish_message"], id="publish"),
|
||||
pytest.param(["pull_messages"], id="pull"),
|
||||
pytest.param(["acknowledge_messages"], id="ack"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub_toolset_tools_selective(selected_tools):
|
||||
"""Test PubSub toolset with filter.
|
||||
|
||||
This test verifies the behavior of the PubSub toolset when filter is
|
||||
specified. A use case for this would be when the agent builder wants to
|
||||
use only a subset of the tools provided by the toolset.
|
||||
|
||||
Args:
|
||||
selected_tools: The list of tools to select from the toolset.
|
||||
"""
|
||||
credentials_config = PubSubCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = PubSubToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(selected_tools)
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set(selected_tools)
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("selected_tools", "returned_tools"),
|
||||
[
|
||||
pytest.param(["unknown"], [], id="all-unknown"),
|
||||
pytest.param(
|
||||
["unknown", "publish_message"],
|
||||
["publish_message"],
|
||||
id="mixed-known-unknown",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools):
|
||||
"""Test PubSub toolset with filter.
|
||||
|
||||
This test verifies the behavior of the PubSub toolset when filter is
|
||||
specified with an unknown tool.
|
||||
|
||||
Args:
|
||||
selected_tools: The list of tools to select from the toolset.
|
||||
returned_tools: The list of tools that are expected to be returned.
|
||||
"""
|
||||
credentials_config = PubSubCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
|
||||
toolset = PubSubToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(returned_tools)
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set(returned_tools)
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
Reference in New Issue
Block a user