feat: Add a handwritten tool for Cloud Pub/Sub

Merge https://github.com/google/adk-python/pull/3865

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3865 from kamalaboulhosn:main 37a38a4dcadb04a0af0ec584e6f611204a63cd2a
PiperOrigin-RevId: 845345128
This commit is contained in:
Kamal Aboul-Hosn
2025-12-16 10:47:35 -08:00
committed by Copybara-Service
parent 7e6ef71eec
commit b6f6dcbeb4
16 changed files with 1504 additions and 0 deletions
+88
View File
@@ -0,0 +1,88 @@
# Pub/Sub Tools Sample
## Introduction
This sample agent demonstrates the Pub/Sub first-party tools in ADK,
distributed via the `google.adk.tools.pubsub` module. These tools include:
1. `publish_message`
Publishes a message to a Pub/Sub topic.
2. `pull_messages`
Pulls messages from a Pub/Sub subscription.
3. `acknowledge_messages`
Acknowledges messages on a Pub/Sub subscription.
## How to use
Set up environment variables in your `.env` file for using
[Google AI Studio](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-ai-studio)
or
[Google Cloud Vertex AI](https://google.github.io/adk-docs/get-started/quickstart/#gemini---google-cloud-vertex-ai)
for the LLM service for your agent. For example, for using Google AI Studio you
would set:
* GOOGLE_GENAI_USE_VERTEXAI=FALSE
* GOOGLE_API_KEY={your api key}
### With Application Default Credentials
This mode is useful for quick development when the agent builder is the only
user interacting with the agent. The tools are run with these credentials.
1. Create application default credentials on the machine where the agent would
be running by following https://cloud.google.com/docs/authentication/provide-credentials-adc.
1. Set `CREDENTIALS_TYPE=None` in `agent.py`
1. Run the agent
### With Service Account Keys
This mode is useful for quick development when the agent builder wants to run
the agent with service account credentials. The tools are run with these
credentials.
1. Create service account key by following https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.SERVICE_ACCOUNT` in `agent.py`
1. Download the key file and replace `"service_account_key.json"` with the path
1. Run the agent
### With Interactive OAuth
1. Follow
https://developers.google.com/identity/protocols/oauth2#1.-obtain-oauth-2.0-credentials-from-the-dynamic_data.setvar.console_name.
to get your client id and client secret. Be sure to choose "web" as your client
type.
1. Follow https://developers.google.com/workspace/guides/configure-oauth-consent to add scope "https://www.googleapis.com/auth/pubsub".
1. Follow https://developers.google.com/identity/protocols/oauth2/web-server#creatingcred to add http://localhost/dev-ui/ to "Authorized redirect URIs".
Note: localhost here is just a hostname that you use to access the dev ui,
replace it with the actual hostname you use to access the dev ui.
1. For 1st run, allow popup for localhost in Chrome.
1. Configure your `.env` file to add two more variables before running the agent:
* OAUTH_CLIENT_ID={your client id}
* OAUTH_CLIENT_SECRET={your client secret}
Note: don't create a separate .env, instead put it to the same .env file that
stores your Vertex AI or Dev ML credentials
1. Set `CREDENTIALS_TYPE=AuthCredentialTypes.OAUTH2` in `agent.py` and run the agent
## Sample prompts
* publish 'Hello World' to 'my-topic'
* pull messages from 'my-subscription'
* acknowledge message 'ack-id' from 'my-subscription'
+15
View File
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
+80
View File
@@ -0,0 +1,80 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from google.adk.agents.llm_agent import LlmAgent
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.tools.pubsub.config import PubSubToolConfig
from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig
from google.adk.tools.pubsub.pubsub_toolset import PubSubToolset
import google.auth
# Define the desired credential type.
# By default use Application Default Credentials (ADC) from the local
# environment, which can be set up by following
# https://cloud.google.com/docs/authentication/provide-credentials-adc.
CREDENTIALS_TYPE = None
# Define an appropriate application name
PUBSUB_AGENT_NAME = "adk_sample_pubsub_agent"
# Define Pub/Sub tool config.
# You can optionally set the project_id here, or let the agent infer it from context/user input.
tool_config = PubSubToolConfig(project_id=os.getenv("GOOGLE_CLOUD_PROJECT"))
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initialize the tools to do interactive OAuth
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
# must be set
credentials_config = PubSubCredentialsConfig(
client_id=os.getenv("OAUTH_CLIENT_ID"),
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
)
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
# Initialize the tools to use the credentials in the service account key.
# If this flow is enabled, make sure to replace the file path with your own
# service account key file
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
credentials_config = PubSubCredentialsConfig(credentials=creds)
else:
# Initialize the tools to use the application default credentials.
# https://cloud.google.com/docs/authentication/provide-credentials-adc
application_default_credentials, _ = google.auth.default()
credentials_config = PubSubCredentialsConfig(
credentials=application_default_credentials
)
pubsub_toolset = PubSubToolset(
credentials_config=credentials_config, pubsub_tool_config=tool_config
)
# The variable name `root_agent` determines what your root agent is for the
# debug CLI
root_agent = LlmAgent(
model="gemini-2.5-flash",
name=PUBSUB_AGENT_NAME,
description=(
"Agent to publish, pull, and acknowledge messages from Google Cloud"
" Pub/Sub."
),
instruction=textwrap.dedent("""\
You are a cloud engineer agent with access to Google Cloud Pub/Sub tools.
You can publish messages to topics, pull messages from subscriptions, and acknowledge messages.
"""),
tools=[pubsub_toolset],
)
+1
View File
@@ -37,6 +37,7 @@ dependencies = [
"google-cloud-bigquery>=2.2.0",
"google-cloud-bigtable>=2.32.0", # For Bigtable database
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
"google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription
@@ -32,6 +32,7 @@ class FeatureName(str, Enum):
GOOGLE_TOOL = "GOOGLE_TOOL"
JSON_SCHEMA_FOR_FUNC_DECL = "JSON_SCHEMA_FOR_FUNC_DECL"
PROGRESSIVE_SSE_STREAMING = "PROGRESSIVE_SSE_STREAMING"
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
SPANNER_TOOLSET = "SPANNER_TOOLSET"
SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS"
@@ -90,6 +91,9 @@ _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
FeatureName.PROGRESSIVE_SSE_STREAMING: FeatureConfig(
FeatureStage.WIP, default_on=False
),
FeatureName.PUBSUB_TOOLSET: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.SPANNER_TOOLSET: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
+30
View File
@@ -0,0 +1,30 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pub/Sub Tools (Experimental).
Pub/Sub Tools under this module are hand crafted and customized while the tools
under google.adk.tools.google_api_tool are auto generated based on API
definition. The rationales to have customized tool are:
1. Better handling of base64 encoding for published messages.
2. A richer subscribe-side API that reflects how users may want to pull/ack
messages.
"""
from .config import PubSubToolConfig
from .pubsub_credentials import PubSubCredentialsConfig
from .pubsub_toolset import PubSubToolset
__all__ = ["PubSubCredentialsConfig", "PubSubToolConfig", "PubSubToolset"]
+165
View File
@@ -0,0 +1,165 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
import time
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import Credentials
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1.types import BatchSettings
from ... import version
USER_AGENT = f"adk-pubsub-tool google-adk/{version.__version__}"
_CACHE_TTL = 1800 # 30 minutes
_publisher_client_cache = {}
_publisher_client_lock = threading.Lock()
def get_publisher_client(
*,
credentials: Credentials,
user_agent: str | list[str] | None = None,
publisher_options: pubsub_v1.types.PublisherOptions | None = None,
) -> pubsub_v1.PublisherClient:
"""Get a Pub/Sub Publisher client.
Args:
credentials: The credentials to use for the request.
user_agent: The user agent to use for the request.
publisher_options: The publisher options to use for the request.
Returns:
A Pub/Sub Publisher client.
"""
global _publisher_client_cache
current_time = time.time()
user_agents_key = None
if user_agent:
if isinstance(user_agent, str):
user_agents_key = (user_agent,)
else:
user_agents_key = tuple(user_agent)
# Use object identity for credentials and publisher_options as they are not hashable
key = (id(credentials), user_agents_key, id(publisher_options))
with _publisher_client_lock:
if key in _publisher_client_cache:
client, expiration = _publisher_client_cache[key]
if expiration > current_time:
return client
user_agents = [USER_AGENT]
if user_agent:
if isinstance(user_agent, str):
user_agents.append(user_agent)
else:
user_agents.extend(ua for ua in user_agent if ua)
client_info = ClientInfo(user_agent=" ".join(user_agents))
# Since we synchronously publish messages, we want to disable batching to
# remove any delay.
custom_batch_settings = BatchSettings(max_messages=1)
publisher_client = pubsub_v1.PublisherClient(
credentials=credentials,
client_info=client_info,
publisher_options=publisher_options,
batch_settings=custom_batch_settings,
)
_publisher_client_cache[key] = (publisher_client, current_time + _CACHE_TTL)
return publisher_client
_subscriber_client_cache = {}
_subscriber_client_lock = threading.Lock()
def get_subscriber_client(
*,
credentials: Credentials,
user_agent: str | list[str] | None = None,
) -> pubsub_v1.SubscriberClient:
"""Get a Pub/Sub Subscriber client.
Args:
credentials: The credentials to use for the request.
user_agent: The user agent to use for the request.
Returns:
A Pub/Sub Subscriber client.
"""
global _subscriber_client_cache
current_time = time.time()
user_agents_key = None
if user_agent:
if isinstance(user_agent, str):
user_agents_key = (user_agent,)
else:
user_agents_key = tuple(user_agent)
# Use object identity for credentials as they are not hashable
key = (id(credentials), user_agents_key)
with _subscriber_client_lock:
if key in _subscriber_client_cache:
client, expiration = _subscriber_client_cache[key]
if expiration > current_time:
return client
user_agents = [USER_AGENT]
if user_agent:
if isinstance(user_agent, str):
user_agents.append(user_agent)
else:
user_agents.extend(ua for ua in user_agent if ua)
client_info = ClientInfo(user_agent=" ".join(user_agents))
subscriber_client = pubsub_v1.SubscriberClient(
credentials=credentials,
client_info=client_info,
)
_subscriber_client_cache[key] = (
subscriber_client,
current_time + _CACHE_TTL,
)
return subscriber_client
def cleanup_clients():
"""Clean up all cached Pub/Sub clients."""
global _publisher_client_cache, _subscriber_client_cache
with _publisher_client_lock:
for client, _ in _publisher_client_cache.values():
client.transport.close()
_publisher_client_cache.clear()
with _subscriber_client_lock:
for client, _ in _subscriber_client_cache.values():
client.close()
_subscriber_client_cache.clear()
+35
View File
@@ -0,0 +1,35 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from pydantic import BaseModel
from pydantic import ConfigDict
from ...utils.feature_decorator import experimental
@experimental('Config defaults may have breaking change in the future.')
class PubSubToolConfig(BaseModel):
"""Configuration for Pub/Sub tools."""
# Forbid any fields not defined in the model
model_config = ConfigDict(extra='forbid')
project_id: str | None = None
"""GCP project ID to use for the Pub/Sub operations.
If not set, the project ID will be inferred from the environment or
credentials.
"""
+187
View File
@@ -0,0 +1,187 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import base64
from typing import Optional
from google.auth.credentials import Credentials
from google.cloud import pubsub_v1
from . import client
from .config import PubSubToolConfig
def publish_message(
topic_name: str,
message: str,
credentials: Credentials,
settings: PubSubToolConfig,
attributes: Optional[dict[str, str]] = None,
ordering_key: str = "",
) -> dict:
"""Publish a message to a Pub/Sub topic.
Args:
topic_name (str): The Pub/Sub topic name (e.g.
projects/my-project/topics/my-topic).
message (str): The message content to publish.
credentials (Credentials): The credentials to use for the request.
settings (PubSubToolConfig): The Pub/Sub tool settings.
attributes (Optional[dict[str, str]]): Attributes to attach to the message.
ordering_key (str): Ordering key for the message.
Returns:
dict: Dictionary with the message_id of the published message.
"""
try:
publisher_options = pubsub_v1.types.PublisherOptions(
enable_message_ordering=bool(ordering_key)
)
publisher_client = client.get_publisher_client(
credentials=credentials,
user_agent=[settings.project_id, "publish_message"],
publisher_options=publisher_options,
)
message_bytes = message.encode("utf-8")
future = publisher_client.publish(
topic_name,
data=message_bytes,
ordering_key=ordering_key,
**(attributes or {}),
)
return {"message_id": future.result()}
except Exception as ex:
return {
"status": "ERROR",
"error_details": (
f"Failed to publish message to topic '{topic_name}': {repr(ex)}"
),
}
def _decode_message_data(data: bytes) -> str:
"""Decodes message data, trying UTF-8 and falling back to base64."""
try:
return data.decode("utf-8")
except UnicodeDecodeError:
# If UTF-8 decoding fails, encode as base64 string
return base64.b64encode(data).decode("ascii")
def pull_messages(
subscription_name: str,
credentials: Credentials,
settings: PubSubToolConfig,
*,
max_messages: int = 1,
auto_ack: bool = False,
) -> dict:
"""Pull messages from a Pub/Sub subscription.
Args:
subscription_name (str): The Pub/Sub subscription name (e.g.
projects/my-project/subscriptions/my-sub).
credentials (Credentials): The credentials to use for the request.
settings (PubSubToolConfig): The Pub/Sub tool settings.
max_messages (int): The maximum number of messages to pull. Defaults to 1.
auto_ack (bool): Whether to automatically acknowledge the messages.
Defaults to False.
Returns:
dict: Dictionary with the list of pulled messages.
"""
try:
subscriber_client = client.get_subscriber_client(
credentials=credentials,
user_agent=[settings.project_id, "pull_messages"],
)
response = subscriber_client.pull(
subscription=subscription_name,
max_messages=max_messages,
)
messages = []
ack_ids = []
for received_message in response.received_messages:
message_data = _decode_message_data(received_message.message.data)
messages.append({
"message_id": received_message.message.message_id,
"data": message_data,
"attributes": dict(received_message.message.attributes),
"ordering_key": received_message.message.ordering_key,
"publish_time": received_message.message.publish_time.rfc3339(),
"ack_id": received_message.ack_id,
})
ack_ids.append(received_message.ack_id)
if auto_ack and ack_ids:
subscriber_client.acknowledge(
subscription=subscription_name,
ack_ids=ack_ids,
)
return {"messages": messages}
except Exception as ex:
return {
"status": "ERROR",
"error_details": (
f"Failed to pull messages from subscription '{subscription_name}':"
f" {repr(ex)}"
),
}
def acknowledge_messages(
subscription_name: str,
ack_ids: list[str],
credentials: Credentials,
settings: PubSubToolConfig,
) -> dict:
"""Acknowledge messages on a Pub/Sub subscription.
Args:
subscription_name (str): The Pub/Sub subscription name (e.g.
projects/my-project/subscriptions/my-sub).
ack_ids (list[str]): List of acknowledgment IDs to acknowledge.
credentials (Credentials): The credentials to use for the request.
settings (PubSubToolConfig): The Pub/Sub tool settings.
Returns:
dict: Status of the operation.
"""
try:
subscriber_client = client.get_subscriber_client(
credentials=credentials,
user_agent=[settings.project_id, "acknowledge_messages"],
)
subscriber_client.acknowledge(
subscription=subscription_name,
ack_ids=ack_ids,
)
return {"status": "SUCCESS"}
except Exception as ex:
return {
"status": "ERROR",
"error_details": (
"Failed to acknowledge messages on subscription"
f" '{subscription_name}': {repr(ex)}"
),
}
@@ -0,0 +1,45 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from pydantic import model_validator
from ...features import experimental
from ...features import FeatureName
from .._google_credentials import BaseGoogleCredentialsConfig
PUBSUB_TOKEN_CACHE_KEY = "pubsub_token_cache"
PUBSUB_DEFAULT_SCOPE = ("https://www.googleapis.com/auth/pubsub",)
@experimental(FeatureName.GOOGLE_CREDENTIALS_CONFIG)
class PubSubCredentialsConfig(BaseGoogleCredentialsConfig):
"""Pub/Sub Credentials Configuration for Google API tools (Experimental).
Please do not use this in production, as it may be deprecated later.
"""
@model_validator(mode="after")
def __post_init__(self) -> PubSubCredentialsConfig:
"""Populate default scope if scopes is None."""
super().__post_init__()
if not self.scopes:
self.scopes = PUBSUB_DEFAULT_SCOPE
# Set the token cache key
self._token_cache_key = PUBSUB_TOKEN_CACHE_KEY
return self
@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.agents.readonly_context import ReadonlyContext
from typing_extensions import override
from . import client
from . import message_tool
from ...features import experimental
from ...features import FeatureName
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ...tools.google_tool import GoogleTool
from .config import PubSubToolConfig
from .pubsub_credentials import PubSubCredentialsConfig
@experimental(FeatureName.PUBSUB_TOOLSET)
class PubSubToolset(BaseToolset):
"""Pub/Sub Toolset contains tools for interacting with Pub/Sub topics and subscriptions."""
def __init__(
self,
*,
tool_filter: ToolPredicate | list[str] | None = None,
credentials_config: PubSubCredentialsConfig | None = None,
pubsub_tool_config: PubSubToolConfig | None = None,
):
"""Initializes the PubSubToolset.
Args:
tool_filter: A predicate or list of tool names to filter the tools in
the toolset. If None, all tools are included.
credentials_config: The credentials configuration to use for
authenticating with Google Cloud.
pubsub_tool_config: The configuration for the Pub/Sub tools.
"""
super().__init__(tool_filter=tool_filter)
self._credentials_config = credentials_config
self._tool_settings = (
pubsub_tool_config if pubsub_tool_config else PubSubToolConfig()
)
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def get_tools(
self, readonly_context: ReadonlyContext | None = None
) -> list[BaseTool]:
"""Get tools from the toolset."""
all_tools = [
GoogleTool(
func=func,
credentials_config=self._credentials_config,
tool_settings=self._tool_settings,
)
for func in [
message_tool.publish_message,
message_tool.pull_messages,
message_tool.acknowledge_messages,
]
]
return [
tool
for tool in all_tools
if self._is_tool_selected(tool, readonly_context)
]
@override
async def close(self):
"""Clean up resources used by the toolset."""
client.cleanup_clients()
@@ -0,0 +1,135 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.pubsub import client
from google.cloud import pubsub_v1
from google.oauth2.credentials import Credentials
import pytest
# Save original Pub/Sub classes before patching.
# This is necessary because create_autospec cannot be used on a mock object,
# and mock.patch.object(..., autospec=True) replaces the class with a mock.
# We need the original class to create spec'd mocks in side_effect.
ORIG_PUBLISHER = pubsub_v1.PublisherClient
ORIG_SUBSCRIBER = pubsub_v1.SubscriberClient
@pytest.fixture(autouse=True)
def cleanup_pubsub_clients():
"""Automatically clean up Pub/Sub client caches after each test.
This fixture runs automatically for all tests in this file,
ensuring that client caches are cleared between tests to prevent
state leakage and ensure test isolation.
"""
yield
client.cleanup_clients()
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
def test_get_publisher_client(mock_publisher_client):
"""Test get_publisher_client factory."""
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
client.get_publisher_client(credentials=mock_creds)
mock_publisher_client.assert_called_once()
_, kwargs = mock_publisher_client.call_args
assert kwargs["credentials"] == mock_creds
assert "client_info" in kwargs
assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings)
assert kwargs["batch_settings"].max_messages == 1
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
def test_get_publisher_client_with_options(mock_publisher_client):
"""Test get_publisher_client factory with options."""
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
mock_options = mock.create_autospec(
pubsub_v1.types.PublisherOptions, instance=True, spec_set=True
)
client.get_publisher_client(
credentials=mock_creds, publisher_options=mock_options
)
mock_publisher_client.assert_called_once()
_, kwargs = mock_publisher_client.call_args
assert kwargs["credentials"] == mock_creds
assert kwargs["publisher_options"] == mock_options
assert "client_info" in kwargs
assert isinstance(kwargs["batch_settings"], pubsub_v1.types.BatchSettings)
assert kwargs["batch_settings"].max_messages == 1
@mock.patch.object(pubsub_v1, "PublisherClient", autospec=True)
def test_get_publisher_client_caching(mock_publisher_client):
"""Test get_publisher_client caching behavior."""
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
mock_publisher_client.side_effect = [
mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True),
mock.create_autospec(ORIG_PUBLISHER, instance=True, spec_set=True),
]
# First call - should create client
client1 = client.get_publisher_client(credentials=mock_creds)
mock_publisher_client.assert_called_once()
# Second call with same args - should return cached client
client2 = client.get_publisher_client(credentials=mock_creds)
assert client1 is client2
mock_publisher_client.assert_called_once() # Still called only once
# Call with different args - should create new client
mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True)
client3 = client.get_publisher_client(credentials=mock_creds2)
assert client3 is not client1
assert mock_publisher_client.call_count == 2
@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True)
def test_get_subscriber_client(mock_subscriber_client):
"""Test get_subscriber_client factory."""
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
client.get_subscriber_client(credentials=mock_creds)
mock_subscriber_client.assert_called_once()
_, kwargs = mock_subscriber_client.call_args
assert kwargs["credentials"] == mock_creds
assert "client_info" in kwargs
@mock.patch.object(pubsub_v1, "SubscriberClient", autospec=True)
def test_get_subscriber_client_caching(mock_subscriber_client):
"""Test get_subscriber_client caching behavior."""
mock_creds = mock.create_autospec(Credentials, instance=True, spec_set=True)
mock_subscriber_client.side_effect = [
mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True),
mock.create_autospec(ORIG_SUBSCRIBER, instance=True, spec_set=True),
]
# First call - should create client
client1 = client.get_subscriber_client(credentials=mock_creds)
mock_subscriber_client.assert_called_once()
# Second call with same args - should return cached client
client2 = client.get_subscriber_client(credentials=mock_creds)
assert client1 is client2
mock_subscriber_client.assert_called_once() # Still called only once
# Call with different args - should create new client
mock_creds2 = mock.create_autospec(Credentials, instance=True, spec_set=True)
client3 = client.get_subscriber_client(credentials=mock_creds2)
assert client3 is not client1
assert mock_subscriber_client.call_count == 2
@@ -0,0 +1,27 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk.tools.pubsub.config import PubSubToolConfig
def test_pubsub_tool_config_init():
"""Test PubSubToolConfig initialization."""
config = PubSubToolConfig(project_id="my-project")
assert config.project_id == "my-project"
def test_pubsub_tool_config_default():
"""Test PubSubToolConfig default initialization."""
config = PubSubToolConfig()
assert config.project_id is None
@@ -0,0 +1,132 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from google.adk.tools.pubsub.pubsub_credentials import PUBSUB_DEFAULT_SCOPE
from google.adk.tools.pubsub.pubsub_credentials import PubSubCredentialsConfig
from google.auth.credentials import Credentials
import google.oauth2.credentials
import pytest
"""Test suite for PubSub credentials configuration validation.
This class tests the credential configuration logic that ensures
either existing credentials or client ID/secret pairs are provided.
"""
def test_pubsub_credentials_config_client_id_secret():
"""Test PubSubCredentialsConfig with client_id and client_secret.
Ensures that when client_id and client_secret are provided, the config
object is created with the correct attributes.
"""
config = PubSubCredentialsConfig(client_id="abc", client_secret="def")
assert config.client_id == "abc"
assert config.client_secret == "def"
assert config.scopes == PUBSUB_DEFAULT_SCOPE
assert config.credentials is None
def test_pubsub_credentials_config_existing_creds():
"""Test PubSubCredentialsConfig with existing generic credentials.
Ensures that when a generic Credentials object is provided, it is
stored correctly.
"""
mock_creds = mock.create_autospec(Credentials, instance=True)
config = PubSubCredentialsConfig(credentials=mock_creds)
assert config.credentials == mock_creds
assert config.client_id is None
assert config.client_secret is None
def test_pubsub_credentials_config_oauth2_creds():
"""Test PubSubCredentialsConfig with existing OAuth2 credentials.
Ensures that when a google.oauth2.credentials.Credentials object is
provided, the client_id, client_secret, and scopes are extracted
from the credentials object.
"""
mock_creds = mock.create_autospec(
google.oauth2.credentials.Credentials, instance=True
)
mock_creds.client_id = "oauth_client_id"
mock_creds.client_secret = "oauth_client_secret"
mock_creds.scopes = ["fake_scope"]
config = PubSubCredentialsConfig(credentials=mock_creds)
assert config.client_id == "oauth_client_id"
assert config.client_secret == "oauth_client_secret"
assert config.scopes == ["fake_scope"]
@pytest.mark.parametrize(
"credentials, client_id, client_secret",
[
# No arguments provided
(None, None, None),
# Only client_id is provided
(None, "abc", None),
],
)
def test_pubsub_credentials_config_validation_errors(
credentials, client_id, client_secret
):
"""Test PubSubCredentialsConfig validation errors.
Ensures that ValueError is raised when invalid combinations of credentials
and client ID/secret are provided.
Args:
credentials: The credentials object to pass.
client_id: The client ID to pass.
client_secret: The client secret to pass.
"""
with pytest.raises(
ValueError,
match=(
"Must provide either credentials or client_id and client_secret pair."
),
):
PubSubCredentialsConfig(
credentials=credentials,
client_id=client_id,
client_secret=client_secret,
)
def test_pubsub_credentials_config_both_credentials_and_client_provided():
"""Test PubSubCredentialsConfig validation errors.
Ensures that ValueError is raised when invalid combinations of credentials
and client ID/secret are provided.
Args:
credentials: The credentials object to pass.
client_id: The client ID to pass.
client_secret: The client secret to pass.
"""
with pytest.raises(
ValueError,
match=(
"Cannot provide both existing credentials and"
" client_id/client_secret/scopes."
),
):
PubSubCredentialsConfig(
credentials=mock.create_autospec(Credentials, instance=True),
client_id="abc",
client_secret="def",
)
@@ -0,0 +1,330 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from unittest import mock
from google.adk.tools.pubsub import client as pubsub_client_lib
from google.adk.tools.pubsub import message_tool
from google.adk.tools.pubsub.config import PubSubToolConfig
from google.api_core import future
from google.cloud import pubsub_v1
from google.cloud.pubsub_v1 import types
from google.oauth2.credentials import Credentials
from google.protobuf import timestamp_pb2
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
def test_publish_message(mock_get_publisher_client, mock_publish):
"""Test publish_message tool invocation."""
topic_name = "projects/my_project_id/topics/my_topic"
message = "Hello World"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_publisher_client = mock.create_autospec(
pubsub_v1.PublisherClient, instance=True
)
mock_get_publisher_client.return_value = mock_publisher_client
mock_future = mock.create_autospec(future.Future, instance=True)
mock_future.result.return_value = "message_id"
mock_publisher_client.publish.return_value = mock_future
result = message_tool.publish_message(
topic_name, message, mock_credentials, tool_settings
)
assert result["message_id"] == "message_id"
mock_get_publisher_client.assert_called_once()
mock_publisher_client.publish.assert_called_once()
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
def test_publish_message_with_ordering_key(
mock_get_publisher_client, mock_publish
):
"""Test publish_message tool invocation with ordering_key."""
topic_name = "projects/my_project_id/topics/my_topic"
message = "Hello World"
ordering_key = "key1"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_publisher_client = mock.create_autospec(
pubsub_v1.PublisherClient, instance=True
)
mock_get_publisher_client.return_value = mock_publisher_client
mock_future = mock.create_autospec(future.Future, instance=True)
mock_future.result.return_value = "message_id"
mock_publisher_client.publish.return_value = mock_future
result = message_tool.publish_message(
topic_name,
message,
mock_credentials,
tool_settings,
ordering_key=ordering_key,
)
assert result["message_id"] == "message_id"
mock_get_publisher_client.assert_called_once()
_, kwargs = mock_get_publisher_client.call_args
assert kwargs["publisher_options"].enable_message_ordering is True
mock_publisher_client.publish.assert_called_once()
# Verify ordering_key was passed
_, kwargs = mock_publisher_client.publish.call_args
assert kwargs["ordering_key"] == ordering_key
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
def test_publish_message_with_attributes(
mock_get_publisher_client, mock_publish
):
"""Test publish_message tool invocation with attributes."""
topic_name = "projects/my_project_id/topics/my_topic"
message = "Hello World"
attributes = {"key1": "value1", "key2": "value2"}
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_publisher_client = mock.create_autospec(
pubsub_v1.PublisherClient, instance=True
)
mock_get_publisher_client.return_value = mock_publisher_client
mock_future = mock.create_autospec(future.Future, instance=True)
mock_future.result.return_value = "message_id"
mock_publisher_client.publish.return_value = mock_future
result = message_tool.publish_message(
topic_name,
message,
mock_credentials,
tool_settings,
attributes=attributes,
)
assert result["message_id"] == "message_id"
mock_get_publisher_client.assert_called_once()
mock_publisher_client.publish.assert_called_once()
# Verify attributes were passed
_, kwargs = mock_publisher_client.publish.call_args
assert kwargs["key1"] == "value1"
assert kwargs["key2"] == "value2"
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_v1.PublisherClient, "publish", autospec=True)
@mock.patch.object(pubsub_client_lib, "get_publisher_client", autospec=True)
def test_publish_message_exception(mock_get_publisher_client, mock_publish):
"""Test publish_message tool invocation when exception occurs."""
topic_name = "projects/my_project_id/topics/my_topic"
message = "Hello World"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_publisher_client = mock.create_autospec(
pubsub_v1.PublisherClient, instance=True
)
mock_get_publisher_client.return_value = mock_publisher_client
# Simulate an exception during publish
mock_publisher_client.publish.side_effect = Exception("Publish failed")
result = message_tool.publish_message(
topic_name,
message,
mock_credentials,
tool_settings,
)
assert result["status"] == "ERROR"
assert "Publish failed" in result["error_details"]
mock_get_publisher_client.assert_called_once()
mock_publisher_client.publish.assert_called_once()
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
def test_pull_messages(mock_get_subscriber_client):
"""Test pull_messages tool invocation."""
subscription_name = "projects/my_project_id/subscriptions/my_sub"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_subscriber_client = mock.create_autospec(
pubsub_v1.SubscriberClient, instance=True
)
mock_get_subscriber_client.return_value = mock_subscriber_client
mock_response = mock.create_autospec(types.PullResponse, instance=True)
mock_message = mock.MagicMock()
mock_message.message.message_id = "123"
mock_message.message.data = b"Hello"
mock_message.message.attributes = {"key": "value"}
mock_message.message.ordering_key = "ABC"
mock_publish_time = mock.MagicMock()
mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z"
mock_message.message.publish_time = mock_publish_time
mock_message.ack_id = "ack_123"
mock_response.received_messages = [mock_message]
mock_subscriber_client.pull.return_value = mock_response
result = message_tool.pull_messages(
subscription_name, mock_credentials, tool_settings
)
expected_message = {
"message_id": "123",
"data": "Hello",
"attributes": {"key": "value"},
"ordering_key": "ABC",
"publish_time": "2023-01-01T00:00:00Z",
"ack_id": "ack_123",
}
assert result["messages"] == [expected_message]
mock_get_subscriber_client.assert_called_once()
mock_subscriber_client.pull.assert_called_once_with(
subscription=subscription_name, max_messages=1
)
mock_subscriber_client.acknowledge.assert_not_called()
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
def test_pull_messages_auto_ack(mock_get_subscriber_client):
"""Test pull_messages tool invocation with auto_ack."""
subscription_name = "projects/my_project_id/subscriptions/my_sub"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_subscriber_client = mock.create_autospec(
pubsub_v1.SubscriberClient, instance=True
)
mock_get_subscriber_client.return_value = mock_subscriber_client
mock_response = mock.create_autospec(types.PullResponse, instance=True)
mock_message = mock.MagicMock()
mock_message.message.message_id = "123"
mock_message.message.data = b"Hello"
mock_message.message.attributes = {}
mock_publish_time = mock.MagicMock()
mock_publish_time.rfc3339.return_value = "2023-01-01T00:00:00Z"
mock_message.message.publish_time = mock_publish_time
mock_message.ack_id = "ack_123"
mock_response.received_messages = [mock_message]
mock_subscriber_client.pull.return_value = mock_response
result = message_tool.pull_messages(
subscription_name,
mock_credentials,
tool_settings,
max_messages=5,
auto_ack=True,
)
assert len(result["messages"]) == 1
mock_get_subscriber_client.assert_called_once()
mock_subscriber_client.pull.assert_called_once_with(
subscription=subscription_name, max_messages=5
)
mock_subscriber_client.acknowledge.assert_called_once_with(
subscription=subscription_name, ack_ids=["ack_123"]
)
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
def test_pull_messages_exception(mock_get_subscriber_client):
"""Test pull_messages tool invocation when exception occurs."""
subscription_name = "projects/my_project_id/subscriptions/my_sub"
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_subscriber_client = mock.create_autospec(
pubsub_v1.SubscriberClient, instance=True
)
mock_get_subscriber_client.return_value = mock_subscriber_client
mock_subscriber_client.pull.side_effect = Exception("Pull failed")
result = message_tool.pull_messages(
subscription_name, mock_credentials, tool_settings
)
assert result["status"] == "ERROR"
assert "Pull failed" in result["error_details"]
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
def test_acknowledge_messages(mock_get_subscriber_client):
"""Test acknowledge_messages tool invocation."""
subscription_name = "projects/my_project_id/subscriptions/my_sub"
ack_ids = ["ack1", "ack2"]
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_subscriber_client = mock.create_autospec(
pubsub_v1.SubscriberClient, instance=True
)
mock_get_subscriber_client.return_value = mock_subscriber_client
result = message_tool.acknowledge_messages(
subscription_name, ack_ids, mock_credentials, tool_settings
)
assert result["status"] == "SUCCESS"
mock_get_subscriber_client.assert_called_once()
mock_subscriber_client.acknowledge.assert_called_once_with(
subscription=subscription_name, ack_ids=ack_ids
)
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(pubsub_client_lib, "get_subscriber_client", autospec=True)
def test_acknowledge_messages_exception(mock_get_subscriber_client):
"""Test acknowledge_messages tool invocation when exception occurs."""
subscription_name = "projects/my_project_id/subscriptions/my_sub"
ack_ids = ["ack1"]
mock_credentials = mock.create_autospec(Credentials, instance=True)
tool_settings = PubSubToolConfig(project_id="my_project_id")
mock_subscriber_client = mock.create_autospec(
pubsub_v1.SubscriberClient, instance=True
)
mock_get_subscriber_client.return_value = mock_subscriber_client
mock_subscriber_client.acknowledge.side_effect = Exception("Ack failed")
result = message_tool.acknowledge_messages(
subscription_name, ack_ids, mock_credentials, tool_settings
)
assert result["status"] == "ERROR"
assert "Ack failed" in result["error_details"]
@@ -0,0 +1,131 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.tools.google_tool import GoogleTool
from google.adk.tools.pubsub import PubSubCredentialsConfig
from google.adk.tools.pubsub import PubSubToolset
from google.adk.tools.pubsub.config import PubSubToolConfig
import pytest
@pytest.mark.asyncio
async def test_pubsub_toolset_tools_default():
"""Test default PubSub toolset.
This test verifies the behavior of the PubSub toolset when no filter is
specified.
"""
credentials_config = PubSubCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = PubSubToolset(
credentials_config=credentials_config, pubsub_tool_config=None
)
# Verify that the tool config is initialized to default values.
assert isinstance(toolset._tool_settings, PubSubToolConfig) # pylint: disable=protected-access
assert toolset._tool_settings.__dict__ == PubSubToolConfig().__dict__ # pylint: disable=protected-access
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 3
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set([
"publish_message",
"pull_messages",
"acknowledge_messages",
])
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
"selected_tools",
[
pytest.param([], id="None"),
pytest.param(["publish_message"], id="publish"),
pytest.param(["pull_messages"], id="pull"),
pytest.param(["acknowledge_messages"], id="ack"),
],
)
@pytest.mark.asyncio
async def test_pubsub_toolset_tools_selective(selected_tools):
"""Test PubSub toolset with filter.
This test verifies the behavior of the PubSub toolset when filter is
specified. A use case for this would be when the agent builder wants to
use only a subset of the tools provided by the toolset.
Args:
selected_tools: The list of tools to select from the toolset.
"""
credentials_config = PubSubCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = PubSubToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(selected_tools)
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set(selected_tools)
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
("selected_tools", "returned_tools"),
[
pytest.param(["unknown"], [], id="all-unknown"),
pytest.param(
["unknown", "publish_message"],
["publish_message"],
id="mixed-known-unknown",
),
],
)
@pytest.mark.asyncio
async def test_pubsub_toolset_unknown_tool(selected_tools, returned_tools):
"""Test PubSub toolset with filter.
This test verifies the behavior of the PubSub toolset when filter is
specified with an unknown tool.
Args:
selected_tools: The list of tools to select from the toolset.
returned_tools: The list of tools that are expected to be returned.
"""
credentials_config = PubSubCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = PubSubToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(returned_tools)
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set(returned_tools)
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names