feat: add Spanner first-party toolset (breaking change to BigQueryTool, consolidating into generic GoogleTool)

The Spanner toolset supports basic operations for interacting with Spanner table metadata and query results.

Consolidate BigQueryTool into generic GoogleTool, so that BigQueryToolset and SpannerToolset can share.

PiperOrigin-RevId: 794259782
This commit is contained in:
Google Team Member
2025-08-12 13:59:05 -07:00
committed by Copybara-Service
parent 10e3dfab1a
commit 1fc8d20ae8
25 changed files with 1716 additions and 320 deletions
+1
View File
@@ -34,6 +34,7 @@ dependencies = [
"google-api-python-client>=2.157.0", # Google API client discovery
"google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store.
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool
"google-cloud-spanner>=3.56.0", # For Spanner database
"google-cloud-speech>=2.30.0", # For Audio Transcription
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.21.1", # Google GenAI SDK
+252
View File
@@ -0,0 +1,252 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from typing import List
from typing import Optional
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
import google.auth.credentials
from google.auth.exceptions import RefreshError
from google.auth.transport.requests import Request
import google.oauth2.credentials
from pydantic import BaseModel
from pydantic import model_validator
from ..auth.auth_credential import AuthCredential
from ..auth.auth_credential import AuthCredentialTypes
from ..auth.auth_credential import OAuth2Auth
from ..auth.auth_tool import AuthConfig
from ..utils.feature_decorator import experimental
from .tool_context import ToolContext
@experimental
class BaseGoogleCredentialsConfig(BaseModel):
  """Base Google Credentials Configuration for Google API tools (Experimental).

  Exactly one of two auth styles must be configured: a pre-existing
  `credentials` object, or a `client_id`/`client_secret` pair (plus optional
  `scopes`) used to run an end-user OAuth flow.

  Please do not use this in production, as it may be deprecated later.
  """

  # Allow non-pydantic types such as google.auth Credentials as field values.
  model_config = {"arbitrary_types_allowed": True}

  credentials: Optional[google.auth.credentials.Credentials] = None
  """The existing auth credentials to use. If set, this credential will be used
  for every end user, end users don't need to be involved in the oauthflow. This
  field is mutually exclusive with client_id, client_secret and scopes.
  Don't set this field unless you are sure this credential has the permission to
  access every end user's data.

  Example usage 1: When the agent is deployed in Google Cloud environment and
  the service account (used as application default credentials) has access to
  all the required Google Cloud resource. Setting this credential to allow user
  to access the Google Cloud resource without end users going through oauth
  flow.

  To get application default credential, use: `google.auth.default(...)`. See
  more details in
  https://cloud.google.com/docs/authentication/application-default-credentials.

  Example usage 2: When the agent wants to access the user's Google Cloud
  resources using the service account key credentials.

  To load service account key credentials, use:
  `google.auth.load_credentials_from_file(...)`. See more details in
  https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.

  When the deployed environment cannot provide a pre-existing credential,
  consider setting below client_id, client_secret and scope for end users to go
  through oauth flow, so that agent can access the user data.
  """

  client_id: Optional[str] = None
  """the oauth client ID to use."""

  client_secret: Optional[str] = None
  """the oauth client secret to use."""

  scopes: Optional[List[str]] = None
  """the scopes to use."""

  _token_cache_key: Optional[str] = None
  """The key to cache the token in the tool context."""

  @model_validator(mode="after")
  def __post_init__(self) -> BaseGoogleCredentialsConfig:
    """Validate that either credentials or client ID/secret are provided.

    Returns:
      The validated config instance (required by `mode="after"` validators).

    Raises:
      ValueError: If neither, or both, of the two mutually exclusive auth
        styles (existing credentials vs. client_id/client_secret) are set.
    """
    if not self.credentials and (not self.client_id or not self.client_secret):
      raise ValueError(
          "Must provide either credentials or client_id and client_secret pair."
      )
    if self.credentials and (
        self.client_id or self.client_secret or self.scopes
    ):
      raise ValueError(
          "Cannot provide both existing credentials and"
          " client_id/client_secret/scopes."
      )
    # For end-user OAuth credentials, mirror the embedded client info onto
    # the config fields so refresh and re-auth flows can reuse them.
    if self.credentials and isinstance(
        self.credentials, google.oauth2.credentials.Credentials
    ):
      self.client_id = self.credentials.client_id
      self.client_secret = self.credentials.client_secret
      self.scopes = self.credentials.scopes
    return self
class GoogleCredentialsManager:
  """Manages Google API credentials with automatic refresh and OAuth flow handling.

  This class centralizes credential management so multiple tools can share
  the same authenticated session without duplicating OAuth logic.
  """

  def __init__(
      self,
      credentials_config: BaseGoogleCredentialsConfig,
  ):
    """Initialize the credential manager.

    Args:
      credentials_config: Credentials config containing client id and client
        secret, or default credentials.
    """
    self.credentials_config = credentials_config

  async def get_valid_credentials(
      self, tool_context: ToolContext
  ) -> Optional[google.auth.credentials.Credentials]:
    """Get valid credentials, handling refresh and OAuth flow as needed.

    Lookup order: cached token in the tool state (only when the config
    defines `_token_cache_key`) -> the config's pre-existing credential ->
    token refresh -> a new end-user OAuth flow.

    Args:
      tool_context: The tool context for OAuth flow and state management

    Returns:
      Valid Credentials object, or None if OAuth flow is needed
    """
    # First, try to get credentials from the tool context
    creds_json = (
        tool_context.state.get(self.credentials_config._token_cache_key, None)
        if self.credentials_config._token_cache_key
        else None
    )
    creds = (
        google.oauth2.credentials.Credentials.from_authorized_user_info(
            json.loads(creds_json), self.credentials_config.scopes
        )
        if creds_json
        else None
    )

    # If credentials are empty, use the default credential
    if not creds:
      creds = self.credentials_config.credentials

    # If non-oauth credentials are provided then use them as is. This helps
    # in flows such as service account keys
    if creds and not isinstance(creds, google.oauth2.credentials.Credentials):
      return creds

    # Check if we have valid credentials
    if creds and creds.valid:
      return creds

    # Try to refresh expired credentials
    if creds and creds.expired and creds.refresh_token:
      try:
        creds.refresh(Request())
        if creds.valid:
          # Cache the refreshed credentials if token cache key is set
          if self.credentials_config._token_cache_key:
            tool_context.state[self.credentials_config._token_cache_key] = (
                creds.to_json()
            )
          return creds
      except RefreshError:
        # Refresh failed, need to re-authenticate
        pass

    # Need to perform OAuth flow
    return await self._perform_oauth_flow(tool_context)

  async def _perform_oauth_flow(
      self, tool_context: ToolContext
  ) -> Optional[google.oauth2.credentials.Credentials]:
    """Perform OAuth flow to get new credentials.

    Args:
      tool_context: The tool context for OAuth flow

    Returns:
      New Credentials object, or None if flow is in progress
    """
    # Create OAuth configuration for Google's authorization-code flow.
    auth_scheme = OAuth2(
        flows=OAuthFlows(
            authorizationCode=OAuthFlowAuthorizationCode(
                authorizationUrl="https://accounts.google.com/o/oauth2/auth",
                tokenUrl="https://oauth2.googleapis.com/token",
                scopes={
                    scope: f"Access to {scope}"
                    for scope in self.credentials_config.scopes
                },
            )
        )
    )
    auth_credential = AuthCredential(
        auth_type=AuthCredentialTypes.OAUTH2,
        oauth2=OAuth2Auth(
            client_id=self.credentials_config.client_id,
            client_secret=self.credentials_config.client_secret,
        ),
    )

    # Check if OAuth response is available (i.e. the user already completed
    # the flow in a previous turn).
    auth_response = tool_context.get_auth_response(
        AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential)
    )

    if auth_response:
      # OAuth flow completed, create credentials
      creds = google.oauth2.credentials.Credentials(
          token=auth_response.oauth2.access_token,
          refresh_token=auth_response.oauth2.refresh_token,
          token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
          client_id=self.credentials_config.client_id,
          client_secret=self.credentials_config.client_secret,
          scopes=list(self.credentials_config.scopes),
      )
      # Cache the new credentials if token cache key is set
      if self.credentials_config._token_cache_key:
        tool_context.state[self.credentials_config._token_cache_key] = (
            creds.to_json()
        )
      return creds
    else:
      # Request OAuth flow; callers receive None until the user completes it.
      tool_context.request_credential(
          AuthConfig(
              auth_scheme=auth_scheme,
              raw_auth_credential=auth_credential,
          )
      )
      return None
@@ -28,11 +28,9 @@ definition. The rationales to have customized tool are:
"""
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
from .bigquery_toolset import BigQueryToolset
__all__ = [
"BigQueryTool",
"BigQueryToolset",
"BigQueryCredentialsConfig",
]
@@ -14,227 +14,28 @@
from __future__ import annotations
import json
from typing import List
from typing import Optional
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
import google.auth.credentials
from google.auth.exceptions import RefreshError
from google.auth.transport.requests import Request
import google.oauth2.credentials
from pydantic import BaseModel
from pydantic import model_validator
from ...auth.auth_credential import AuthCredential
from ...auth.auth_credential import AuthCredentialTypes
from ...auth.auth_credential import OAuth2Auth
from ...auth.auth_tool import AuthConfig
from ...utils.feature_decorator import experimental
from ..tool_context import ToolContext
from .._google_credentials import BaseGoogleCredentialsConfig
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"]
@experimental
class BigQueryCredentialsConfig(BaseModel):
"""Configuration for Google API tools (Experimental).
class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig):
"""BigQuery Credentials Configuration for Google API tools (Experimental).
Please do not use this in production, as it may be deprecated later.
"""
# Configure the model to allow arbitrary types like Credentials
model_config = {"arbitrary_types_allowed": True}
credentials: Optional[google.auth.credentials.Credentials] = None
"""The existing auth credentials to use. If set, this credential will be used
for every end user, end users don't need to be involved in the oauthflow. This
field is mutually exclusive with client_id, client_secret and scopes.
Don't set this field unless you are sure this credential has the permission to
access every end user's data.
Example usage 1: When the agent is deployed in Google Cloud environment and
the service account (used as application default credentials) has access to
all the required BigQuery resource. Setting this credential to allow user to
access the BigQuery resource without end users going through oauth flow.
To get application default credential, use: `google.auth.default(...)`. See more
details in https://cloud.google.com/docs/authentication/application-default-credentials.
Example usage 2: When the agent wants to access the user's BigQuery resources
using the service account key credentials.
To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`.
See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys.
When the deployed environment cannot provide a pre-existing credential,
consider setting below client_id, client_secret and scope for end users to go
through oauth flow, so that agent can access the user data.
"""
client_id: Optional[str] = None
"""the oauth client ID to use."""
client_secret: Optional[str] = None
"""the oauth client secret to use."""
scopes: Optional[List[str]] = None
"""the scopes to use."""
@model_validator(mode="after")
def __post_init__(self) -> BigQueryCredentialsConfig:
"""Validate that either credentials or client ID/secret are provided."""
if not self.credentials and (not self.client_id or not self.client_secret):
raise ValueError(
"Must provide either credentials or client_id and client_secret pair."
)
if self.credentials and (
self.client_id or self.client_secret or self.scopes
):
raise ValueError(
"Cannot provide both existing credentials and"
" client_id/client_secret/scopes."
)
if self.credentials and isinstance(
self.credentials, google.oauth2.credentials.Credentials
):
self.client_id = self.credentials.client_id
self.client_secret = self.credentials.client_secret
self.scopes = self.credentials.scopes
"""Populate default scope if scopes is None."""
super().__post_init__()
if not self.scopes:
self.scopes = BIGQUERY_DEFAULT_SCOPE
# Set the token cache key
self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY
return self
class BigQueryCredentialsManager:
"""Manages Google API credentials with automatic refresh and OAuth flow handling.
This class centralizes credential management so multiple tools can share
the same authenticated session without duplicating OAuth logic.
"""
def __init__(self, credentials_config: BigQueryCredentialsConfig):
"""Initialize the credential manager.
Args:
credentials_config: Credentials containing client id and client secret
or default credentials
"""
self.credentials_config = credentials_config
async def get_valid_credentials(
self, tool_context: ToolContext
) -> Optional[google.auth.credentials.Credentials]:
"""Get valid credentials, handling refresh and OAuth flow as needed.
Args:
tool_context: The tool context for OAuth flow and state management
Returns:
Valid Credentials object, or None if OAuth flow is needed
"""
# First, try to get credentials from the tool context
creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None)
creds = (
google.oauth2.credentials.Credentials.from_authorized_user_info(
json.loads(creds_json), self.credentials_config.scopes
)
if creds_json
else None
)
# If credentials are empty, use the default credential
if not creds:
creds = self.credentials_config.credentials
# If non-oauth credentials are provided then use them as is. This helps
# in flows such as service account keys
if creds and not isinstance(creds, google.oauth2.credentials.Credentials):
return creds
# Check if we have valid credentials
if creds and creds.valid:
return creds
# Try to refresh expired credentials
if creds and creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
if creds.valid:
# Cache the refreshed credentials
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
return creds
except RefreshError:
# Refresh failed, need to re-authenticate
pass
# Need to perform OAuth flow
return await self._perform_oauth_flow(tool_context)
async def _perform_oauth_flow(
self, tool_context: ToolContext
) -> Optional[google.oauth2.credentials.Credentials]:
"""Perform OAuth flow to get new credentials.
Args:
tool_context: The tool context for OAuth flow
required_scopes: Set of required OAuth scopes
Returns:
New Credentials object, or None if flow is in progress
"""
# Create OAuth configuration
auth_scheme = OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
tokenUrl="https://oauth2.googleapis.com/token",
scopes={
scope: f"Access to {scope}"
for scope in self.credentials_config.scopes
},
)
)
)
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id=self.credentials_config.client_id,
client_secret=self.credentials_config.client_secret,
),
)
# Check if OAuth response is available
auth_response = tool_context.get_auth_response(
AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential)
)
if auth_response:
# OAuth flow completed, create credentials
creds = google.oauth2.credentials.Credentials(
token=auth_response.oauth2.access_token,
refresh_token=auth_response.oauth2.refresh_token,
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
client_id=self.credentials_config.client_id,
client_secret=self.credentials_config.client_secret,
scopes=list(self.credentials_config.scopes),
)
# Cache the new credentials
tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json()
return creds
else:
# Request OAuth flow
tool_context.request_credential(
AuthConfig(
auth_scheme=auth_scheme,
raw_auth_credential=auth_credential,
)
)
return None
@@ -26,9 +26,9 @@ from . import query_tool
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ...tools.google_tool import GoogleTool
from ...utils.feature_decorator import experimental
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
from .config import BigQueryToolConfig
@@ -45,7 +45,9 @@ class BigQueryToolset(BaseToolset):
):
self.tool_filter = tool_filter
self._credentials_config = credentials_config
self._tool_config = bigquery_tool_config
self._tool_settings = (
bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
)
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
@@ -67,17 +69,17 @@ class BigQueryToolset(BaseToolset):
) -> List[BaseTool]:
"""Get tools from the toolset."""
all_tools = [
BigQueryTool(
GoogleTool(
func=func,
credentials_config=self._credentials_config,
bigquery_tool_config=self._tool_config,
tool_settings=self._tool_settings,
)
for func in [
metadata_tool.get_dataset_info,
metadata_tool.get_table_info,
metadata_tool.list_dataset_ids,
metadata_tool.list_table_ids,
query_tool.get_execute_sql(self._tool_config),
query_tool.get_execute_sql(self._tool_settings),
]
]
@@ -30,7 +30,7 @@ def ask_data_insights(
user_query_with_context: str,
table_references: List[Dict[str, str]],
credentials: Credentials,
config: BigQueryToolConfig,
settings: BigQueryToolConfig,
) -> Dict[str, Any]:
"""Answers questions about structured data in BigQuery tables using natural language.
@@ -53,7 +53,7 @@ def ask_data_insights(
table_references (List[Dict[str, str]]): A list of dictionaries, each
specifying a BigQuery table to be used as context for the question.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
settings (BigQueryToolConfig): The settings for the tool.
Returns:
A dictionary with two keys:
@@ -135,7 +135,7 @@ def ask_data_insights(
}
resp = _get_stream(
ca_url, ca_payload, headers, config.max_query_result_rows
ca_url, ca_payload, headers, settings.max_query_result_rows
)
except Exception as ex: # pylint: disable=broad-except
return {
+13 -13
View File
@@ -36,7 +36,7 @@ def execute_sql(
project_id: str,
query: str,
credentials: Credentials,
config: BigQueryToolConfig,
settings: BigQueryToolConfig,
tool_context: ToolContext,
) -> dict:
"""Run a BigQuery or BigQuery ML SQL query in the project and return the result.
@@ -46,7 +46,7 @@ def execute_sql(
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
settings (BigQueryToolConfig): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
@@ -89,7 +89,7 @@ def execute_sql(
# BigQuery connection properties where applicable
bq_connection_properties = None
if not config or config.write_mode == WriteMode.BLOCKED:
if not settings or settings.write_mode == WriteMode.BLOCKED:
dry_run_query_job = bq_client.query(
query,
project=project_id,
@@ -100,7 +100,7 @@ def execute_sql(
"status": "ERROR",
"error_details": "Read-only mode only supports SELECT statements.",
}
elif config.write_mode == WriteMode.PROTECTED:
elif settings.write_mode == WriteMode.PROTECTED:
# In protected write mode, write operation only to a temporary artifact is
# allowed. This artifact must have been created in a BigQuery session. In
# such a scenario the session info (session id and the anonymous dataset
@@ -161,7 +161,7 @@ def execute_sql(
query,
job_config=job_config,
project=project_id,
max_results=config.max_query_result_rows,
max_results=settings.max_query_result_rows,
)
rows = []
for row in row_iterator:
@@ -177,8 +177,8 @@ def execute_sql(
result = {"status": "SUCCESS", "rows": rows}
if (
config.max_query_result_rows is not None
and len(rows) == config.max_query_result_rows
settings.max_query_result_rows is not None
and len(rows) == settings.max_query_result_rows
):
result["result_is_likely_truncated"] = True
return result
@@ -462,19 +462,19 @@ _execute_sql_protecetd_write_examples = """
"""
def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
"""Get the execute_sql tool customized as per the given tool config.
def get_execute_sql(settings: BigQueryToolConfig) -> Callable[..., dict]:
"""Get the execute_sql tool customized as per the given tool settings.
Args:
config: BigQuery tool configuration indicating the behavior of the
settings: BigQuery tool settings indicating the behavior of the
execute_sql tool.
Returns:
callable[..., dict]: A version of the execute_sql tool respecting the tool
config.
settings.
"""
if not config or config.write_mode == WriteMode.BLOCKED:
if not settings or settings.write_mode == WriteMode.BLOCKED:
return execute_sql
# Create a new function object using the original function's code and globals.
@@ -495,7 +495,7 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
functools.update_wrapper(execute_sql_wrapper, execute_sql)
# Now, set the new docstring
if config.write_mode == WriteMode.PROTECTED:
if settings.write_mode == WriteMode.PROTECTED:
examples = _execute_sql_protecetd_write_examples
else:
examples = _execute_sql_write_examples
@@ -20,19 +20,19 @@ from typing import Callable
from typing import Optional
from google.auth.credentials import Credentials
from pydantic import BaseModel
from typing_extensions import override
from ...utils.feature_decorator import experimental
from ..function_tool import FunctionTool
from ..tool_context import ToolContext
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_credentials import BigQueryCredentialsManager
from .config import BigQueryToolConfig
from ..utils.feature_decorator import experimental
from ._google_credentials import BaseGoogleCredentialsConfig
from ._google_credentials import GoogleCredentialsManager
from .function_tool import FunctionTool
from .tool_context import ToolContext
@experimental
class BigQueryTool(FunctionTool):
"""GoogleApiTool class for tools that call Google APIs.
class GoogleTool(FunctionTool):
"""GoogleTool class for tools that call Google APIs.
This class is for developers to handcraft customized Google API tools rather
than auto generate Google API tools based on API specs.
@@ -46,8 +46,8 @@ class BigQueryTool(FunctionTool):
self,
func: Callable[..., Any],
*,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
credentials_config: Optional[BaseGoogleCredentialsConfig] = None,
tool_settings: Optional[BaseModel] = None,
):
"""Initialize the Google API tool.
@@ -56,18 +56,18 @@ class BigQueryTool(FunctionTool):
'credential" parameter
credentials_config: credentials config used to call Google API. If None,
then we don't handle the auth logic
tool_settings: Tool-specific settings. This settings should be provided
by each toolset that uses this class to create customized tools.
"""
super().__init__(func=func)
self._ignore_params.append("credentials")
self._ignore_params.append("config")
self._ignore_params.append("settings")
self._credentials_manager = (
BigQueryCredentialsManager(credentials_config)
GoogleCredentialsManager(credentials_config)
if credentials_config
else None
)
self._tool_config = (
bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
)
self._tool_settings = tool_settings
@override
async def run_async(
@@ -96,7 +96,7 @@ class BigQueryTool(FunctionTool):
# Execute the tool's specific logic with valid credentials
return await self._run_async_with_credential(
credentials, self._tool_config, args, tool_context
credentials, self._tool_settings, args, tool_context
)
except Exception as ex:
@@ -108,7 +108,7 @@ class BigQueryTool(FunctionTool):
async def _run_async_with_credential(
self,
credentials: Credentials,
tool_config: BigQueryToolConfig,
tool_settings: BaseModel,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
@@ -116,6 +116,7 @@ class BigQueryTool(FunctionTool):
Args:
credentials: Valid Google OAuth credentials
tool_settings: Tool settings
args: Arguments passed to the tool
tool_context: Tool execution context
@@ -126,6 +127,6 @@ class BigQueryTool(FunctionTool):
signature = inspect.signature(self.func)
if "credentials" in signature.parameters:
args_to_call["credentials"] = credentials
if "config" in signature.parameters:
args_to_call["config"] = tool_config
if "settings" in signature.parameters:
args_to_call["settings"] = tool_settings
return await super().run_async(args=args_to_call, tool_context=tool_context)
+40
View File
@@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner Tools (Experimental).
Spanner Tools under this module are hand crafted and customized while the tools
under google.adk.tools.google_api_tool are auto generated based on API
definition. The rationales to have customized tool are:
1. A dedicated Spanner toolset to provide an easier, integrated way to interact
with Spanner database and tables for building AI Agent applications quickly.
2. We want to provide more high-level tools like Search, ML.Predict, and Graph
etc.
3. We want to provide extra access guardrails and controls in those tools.
For example, execute_sql can't arbitrarily mutate existing data.
4. We want to provide Spanner best practices and knowledge assistants for ad-hoc
analytics queries.
5. Use Spanner Toolset for more customization and control to interact with
Spanner database and tables.
"""
from . import spanner_credentials
from .spanner_toolset import SpannerToolset
SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig
__all__ = [
"SpannerToolset",
"SpannerCredentialsConfig",
]
+33
View File
@@ -0,0 +1,33 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.auth.credentials import Credentials
from google.cloud import spanner
from ... import version
USER_AGENT = f"adk-spanner-tool google-adk/{version.__version__}"
def get_spanner_client(
    *, project: str, credentials: Credentials
) -> spanner.Client:
  """Create a Spanner client for the given project and credentials.

  The returned client reports the ADK-specific user agent so server-side
  telemetry can attribute traffic to the ADK Spanner toolset.
  """
  adk_client = spanner.Client(project=project, credentials=credentials)
  # NOTE(review): `_client_info` is a private attribute of the client; there
  # appears to be no public API to override the user agent — confirm on
  # google-cloud-spanner upgrades.
  adk_client._client_info.user_agent = USER_AGENT
  return adk_client
File diff suppressed because it is too large Load Diff
+114
View File
@@ -0,0 +1,114 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from google.auth.credentials import Credentials
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from . import client
from ..tool_context import ToolContext
from .settings import SpannerToolSettings
DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50
def execute_sql(
    project_id: str,
    instance_id: str,
    database_id: str,
    query: str,
    credentials: Credentials,
    settings: SpannerToolSettings,
    tool_context: ToolContext,
) -> dict:
  """Run a Spanner Read-Only query in the spanner database and return the result.

  Args:
    project_id (str): The GCP project id in which the spanner database
      resides.
    instance_id (str): The instance id of the spanner database.
    database_id (str): The database id of the spanner database.
    query (str): The Spanner SQL query to be executed.
    credentials (Credentials): The credentials to use for the request.
    settings (SpannerToolSettings): The settings for the tool.
    tool_context (ToolContext): The context for the tool.

  Returns:
    dict: Dictionary with the result of the query.
          If the result contains the key "result_is_likely_truncated" with
          value True, it means that there may be additional rows matching the
          query not returned in the result.

  Examples:
    Fetch data or insights from a table:

        >>> execute_sql("my_project", "my_instance", "my_database",
        ... "SELECT COUNT(*) AS count FROM my_table")
        {
          "status": "SUCCESS",
          "rows": [
              [100]
          ]
        }

  Note:
    This is running with Read-Only Transaction for query that only read data.
  """
  try:
    # Get Spanner client
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    # Only the GoogleSQL dialect is supported by this tool.
    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # Read-only snapshot transaction: the query cannot mutate data.
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(query)

      rows = []
      # Row budget; a missing or non-positive setting falls back to the
      # module default.
      counter = (
          settings.max_executed_query_result_rows
          if settings and settings.max_executed_query_result_rows > 0
          else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
      )
      for row in result_set:
        try:
          # If the JSON serialization of the row succeeds, use it as is;
          # otherwise fall back to its string representation. json.dumps
          # raises TypeError for unsupported types and ValueError for
          # values it rejects (e.g. circular references), so catch exactly
          # those instead of a bare `except` that would also swallow
          # KeyboardInterrupt/SystemExit.
          json.dumps(row)
        except (TypeError, ValueError):
          row = str(row)
        rows.append(row)
        counter -= 1
        if counter <= 0:
          break

    result = {"status": "SUCCESS", "rows": rows}
    if counter <= 0:
      # We stopped at the row budget, so more matching rows may exist.
      result["result_is_likely_truncated"] = True
    return result
  except Exception as ex:  # pylint: disable=broad-except
    # Boundary handler: surface any failure as a structured error payload
    # instead of raising into the agent framework.
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
+46
View File
@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from typing import List
from pydantic import BaseModel
from ...utils.feature_decorator import experimental
class Capabilities(Enum):
  """Operation categories that Spanner tools may be permitted to perform."""

  DATA_READ = 'data_read'
  """Only tools that read data are allowed; no mutations."""
@experimental('Tool settings defaults may have breaking change in the future.')
class SpannerToolSettings(BaseModel):
  """Settings for Spanner tools.

  Controls which capabilities the Spanner toolset exposes and how much data
  query tools may return.
  """

  capabilities: List[Capabilities] = [
      Capabilities.DATA_READ,
  ]
  """Allowed capabilities for Spanner tools.

  By default, the tool will allow only read operations. This behaviour may
  change in future versions.
  """

  max_executed_query_result_rows: int = 50
  """Maximum number of rows to return from a query result.

  Rows beyond this count are dropped by the query tool, which then marks the
  response as likely truncated.
  """
@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ...utils.feature_decorator import experimental
from .._google_credentials import BaseGoogleCredentialsConfig
SPANNER_TOKEN_CACHE_KEY = "spanner_token_cache"
SPANNER_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/spanner.data"]
@experimental
class SpannerCredentialsConfig(BaseGoogleCredentialsConfig):
  """Spanner Credentials Configuration for Google API tools (Experimental).

  Please do not use this in production, as it may be deprecated later.
  """

  def __post_init__(self) -> SpannerCredentialsConfig:
    """Populate default scope if scopes is None.

    Returns:
      The (mutated) config instance, following the base-class convention.
    """
    super().__post_init__()
    if not self.scopes:
      # Copy the module-level default list so later mutation of `self.scopes`
      # on one instance cannot leak into SPANNER_DEFAULT_SCOPE and thereby
      # into other instances.
      self.scopes = list(SPANNER_DEFAULT_SCOPE)
    # Set the token cache key used for this tool family.
    self._token_cache_key = SPANNER_TOKEN_CACHE_KEY
    return self
@@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from google.adk.agents.readonly_context import ReadonlyContext
from typing_extensions import override
from . import metadata_tool
from . import query_tool
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ...tools.google_tool import GoogleTool
from ...utils.feature_decorator import experimental
from .settings import Capabilities
from .settings import SpannerToolSettings
from .spanner_credentials import SpannerCredentialsConfig
@experimental
class SpannerToolset(BaseToolset):
  """Toolset for interacting with Spanner data, database and table information.

  Exposes metadata tools always, and the `execute_sql` query tool when the
  settings grant the DATA_READ capability.
  """

  def __init__(
      self,
      *,
      tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
      credentials_config: Optional[SpannerCredentialsConfig] = None,
      spanner_tool_settings: Optional[SpannerToolSettings] = None,
  ):
    self.tool_filter = tool_filter
    self._credentials_config = credentials_config
    # Fall back to default settings when the caller supplies none.
    self._tool_settings = spanner_tool_settings or SpannerToolSettings()

  def _is_tool_selected(
      self, tool: BaseTool, readonly_context: ReadonlyContext
  ) -> bool:
    """Return True when `tool` passes the configured filter."""
    tool_filter = self.tool_filter
    if tool_filter is None:
      # No filter configured: every tool is selected.
      return True
    if isinstance(tool_filter, ToolPredicate):
      return tool_filter(tool, readonly_context)
    # A list filter selects tools by name; any other type selects nothing.
    return isinstance(tool_filter, list) and tool.name in tool_filter

  @override
  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> List[BaseTool]:
    """Get tools from the toolset."""
    funcs = [
        # Metadata tools are always available.
        metadata_tool.list_table_names,
        metadata_tool.list_table_indexes,
        metadata_tool.list_table_index_columns,
        metadata_tool.list_named_schemas,
        metadata_tool.get_table_schema,
    ]
    # Query tools are gated on the read capability.
    if (
        self._tool_settings
        and Capabilities.DATA_READ in self._tool_settings.capabilities
    ):
      funcs.append(query_tool.execute_sql)
    candidates = [
        GoogleTool(
            func=func,
            credentials_config=self._credentials_config,
            tool_settings=self._tool_settings,
        )
        for func in funcs
    ]
    return [
        tool
        for tool in candidates
        if self._is_tool_selected(tool, readonly_context)
    ]

  @override
  async def close(self):
    """No-op: this toolset holds no resources that need releasing."""
@@ -74,8 +74,8 @@ def test_ask_data_insights_success(mock_get_stream):
# 2. Create mock inputs for the function call
mock_creds = mock.Mock()
mock_creds.token = "fake-token"
mock_config = mock.Mock()
mock_config.max_query_result_rows = 100
mock_settings = mock.Mock()
mock_settings.max_query_result_rows = 100
# 3. Call the function under test
result = data_insights_tool.ask_data_insights(
@@ -83,7 +83,7 @@ def test_ask_data_insights_success(mock_get_stream):
user_query_with_context="test query",
table_references=[],
credentials=mock_creds,
config=mock_config,
settings=mock_settings,
)
# 4. Assert the results are as expected
@@ -101,7 +101,7 @@ def test_ask_data_insights_handles_exception(mock_get_stream):
# 2. Create mock inputs
mock_creds = mock.Mock()
mock_creds.token = "fake-token"
mock_config = mock.Mock()
mock_settings = mock.Mock()
# 3. Call the function
result = data_insights_tool.ask_data_insights(
@@ -109,7 +109,7 @@ def test_ask_data_insights_handles_exception(mock_get_stream):
user_query_with_context="test query",
table_references=[],
credentials=mock_creds,
config=mock_config,
settings=mock_settings,
)
# 4. Assert that the error was caught and formatted correctly
@@ -37,7 +37,7 @@ import pytest
async def get_tool(
name: str, tool_config: Optional[BigQueryToolConfig] = None
name: str, tool_settings: Optional[BigQueryToolConfig] = None
) -> BaseTool:
"""Get a tool from BigQuery toolset.
@@ -54,7 +54,7 @@ async def get_tool(
toolset = BigQueryToolset(
credentials_config=credentials_config,
tool_filter=[name],
bigquery_tool_config=tool_config,
bigquery_tool_config=tool_settings,
)
tools = await toolset.get_tools()
@@ -64,7 +64,7 @@ async def get_tool(
@pytest.mark.parametrize(
("tool_config",),
("tool_settings",),
[
pytest.param(None, id="no-config"),
pytest.param(BigQueryToolConfig(), id="default-config"),
@@ -75,14 +75,14 @@ async def get_tool(
],
)
@pytest.mark.asyncio
async def test_execute_sql_declaration_read_only(tool_config):
async def test_execute_sql_declaration_read_only(tool_settings):
"""Test BigQuery execute_sql tool declaration in read-only mode.
This test verifies that the execute_sql tool declaration reflects the
read-only capability.
"""
tool_name = "execute_sql"
tool = await get_tool(tool_name, tool_config)
tool = await get_tool(tool_name, tool_settings)
assert tool.name == tool_name
assert tool.description == textwrap.dedent("""\
Run a BigQuery or BigQuery ML SQL query in the project and return the result.
@@ -92,7 +92,7 @@ async def test_execute_sql_declaration_read_only(tool_config):
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
settings (BigQueryToolConfig): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
@@ -127,7 +127,7 @@ async def test_execute_sql_declaration_read_only(tool_config):
@pytest.mark.parametrize(
("tool_config",),
("tool_settings",),
[
pytest.param(
BigQueryToolConfig(write_mode=WriteMode.ALLOWED),
@@ -136,14 +136,14 @@ async def test_execute_sql_declaration_read_only(tool_config):
],
)
@pytest.mark.asyncio
async def test_execute_sql_declaration_write(tool_config):
async def test_execute_sql_declaration_write(tool_settings):
"""Test BigQuery execute_sql tool declaration with all writes enabled.
This test verifies that the execute_sql tool declaration reflects the write
capability.
"""
tool_name = "execute_sql"
tool = await get_tool(tool_name, tool_config)
tool = await get_tool(tool_name, tool_settings)
assert tool.name == tool_name
assert tool.description == textwrap.dedent("""\
Run a BigQuery or BigQuery ML SQL query in the project and return the result.
@@ -153,7 +153,7 @@ async def test_execute_sql_declaration_write(tool_config):
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
settings (BigQueryToolConfig): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
@@ -326,7 +326,7 @@ async def test_execute_sql_declaration_write(tool_config):
@pytest.mark.parametrize(
("tool_config",),
("tool_settings",),
[
pytest.param(
BigQueryToolConfig(write_mode=WriteMode.PROTECTED),
@@ -335,14 +335,14 @@ async def test_execute_sql_declaration_write(tool_config):
],
)
@pytest.mark.asyncio
async def test_execute_sql_declaration_protected_write(tool_config):
async def test_execute_sql_declaration_protected_write(tool_settings):
"""Test BigQuery execute_sql tool declaration with protected writes enabled.
This test verifies that the execute_sql tool declaration reflects the
protected write capability.
"""
tool_name = "execute_sql"
tool = await get_tool(tool_name, tool_config)
tool = await get_tool(tool_name, tool_settings)
assert tool.name == tool_name
assert tool.description == textwrap.dedent("""\
Run a BigQuery or BigQuery ML SQL query in the project and return the result.
@@ -352,7 +352,7 @@ async def test_execute_sql_declaration_protected_write(tool_config):
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
settings (BigQueryToolConfig): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
@@ -530,7 +530,7 @@ def test_execute_sql_select_stmt(write_mode):
statement_type = "SELECT"
query_result = [{"num": 123}]
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=write_mode)
tool_settings = BigQueryToolConfig(write_mode=write_mode)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = (
"test-bq-session-id",
@@ -550,7 +550,9 @@ def test_execute_sql_select_stmt(write_mode):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -586,7 +588,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
@@ -602,7 +604,9 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -638,7 +642,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
@@ -654,7 +658,9 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {
"status": "ERROR",
"error_details": "Read-only mode only supports SELECT statements.",
@@ -693,7 +699,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = (
"test-bq-session-id",
@@ -714,7 +720,9 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -756,7 +764,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = (
"test-bq-session-id",
@@ -777,7 +785,9 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {
"status": "ERROR",
"error_details": (
@@ -808,7 +818,7 @@ def test_execute_sql_no_default_auth(
statement_type = "SELECT"
query_result = [{"num": 123}]
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=write_mode)
tool_settings = BigQueryToolConfig(write_mode=write_mode)
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = (
"test-bq-session-id",
@@ -830,7 +840,7 @@ def test_execute_sql_no_default_auth(
mock_query_and_wait.return_value = query_result
# Test the tool worked without invoking default auth
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(project, query, credentials, tool_settings, tool_context)
assert result == {"status": "SUCCESS", "rows": query_result}
mock_default_auth.assert_not_called()
@@ -959,7 +969,7 @@ def test_execute_sql_result_dtype(
project = "my_project"
statement_type = "SELECT"
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig()
tool_settings = BigQueryToolConfig()
tool_context = mock.create_autospec(ToolContext, instance=True)
# Simulate the result of query API
@@ -971,5 +981,5 @@ def test_execute_sql_result_dtype(
mock_query_and_wait.return_value = query_result
# Test the tool worked without invoking default auth
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = execute_sql(project, query, credentials, tool_settings, tool_context)
assert result == {"status": "SUCCESS", "rows": tool_result_rows}
@@ -15,8 +15,9 @@
from __future__ import annotations
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryTool
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.google_tool import GoogleTool
import pytest
@@ -30,12 +31,18 @@ async def test_bigquery_toolset_tools_default():
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryToolset(credentials_config=credentials_config)
toolset = BigQueryToolset(
credentials_config=credentials_config, bigquery_tool_config=None
)
# Verify that the tool config is initialized to default values.
assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access
assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 5
assert all([isinstance(tool, BigQueryTool) for tool in tools])
assert all([isinstance(tool, GoogleTool) for tool in tools])
expected_tool_names = set([
"list_dataset_ids",
@@ -77,7 +84,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools):
assert tools is not None
assert len(tools) == len(selected_tools)
assert all([isinstance(tool, BigQueryTool) for tool in tools])
assert all([isinstance(tool, GoogleTool) for tool in tools])
expected_tool_names = set(selected_tools)
actual_tool_names = set([tool.name for tool in tools])
@@ -114,7 +121,7 @@ async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools):
assert tools is not None
assert len(tools) == len(returned_tools)
assert all([isinstance(tool, BigQueryTool) for tool in tools])
assert all([isinstance(tool, GoogleTool) for tool in tools])
expected_tool_names = set(returned_tools)
actual_tool_names = set([tool.name for tool in tools])
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,142 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import re
from unittest import mock
from google.adk.tools.spanner.client import get_spanner_client
from google.auth.exceptions import DefaultCredentialsError
from google.oauth2.credentials import Credentials
import pytest
def test_spanner_client_project():
  """Verify the Spanner client is created with the requested project."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)

  # Trigger the spanner client creation.
  client = get_spanner_client(
      project="test-gcp-project", credentials=fake_credentials
  )

  # The client must carry the project that was passed in.
  assert client.project == "test-gcp-project"
def test_spanner_client_project_set_explicit():
  """Client creation must not invoke default auth when a project is given."""
  # Clear the environment so no ambient project interferes, and make default
  # auth explode if it is ever consulted.
  with mock.patch.dict(os.environ, {}, clear=True), mock.patch(
      "google.auth.default", autospec=True
  ) as mock_default_auth:
    mock_default_auth.side_effect = DefaultCredentialsError(
        "Your default credentials were not found"
    )

    client = get_spanner_client(
        project="test-gcp-project",
        credentials=mock.create_autospec(Credentials, instance=True),
    )

    # Reaching this point already proves default auth was never hit
    # (otherwise DefaultCredentialsError would have propagated); assert it
    # explicitly anyway, and check the project stuck.
    mock_default_auth.assert_not_called()
    assert client.project == "test-gcp-project"
def test_spanner_client_project_set_with_default_auth():
  """The client project should be resolved via default auth when not given."""
  mock_creds = mock.create_autospec(Credentials, instance=True)

  # Clear the environment so the project can only come from default auth,
  # which we stub to yield a known project.
  with mock.patch.dict(os.environ, {}, clear=True), mock.patch(
      "google.auth.default",
      autospec=True,
      return_value=(mock_creds, "test-gcp-project"),
  ) as mock_default_auth:
    client = get_spanner_client(project=None, credentials=mock_creds)

    # Default auth must have been consulted exactly once for the project.
    mock_default_auth.assert_called_once()
    assert client.project == "test-gcp-project"
def test_spanner_client_project_set_with_env():
  """The client project should come from GOOGLE_CLOUD_PROJECT when set."""
  env = {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}
  with mock.patch.dict(os.environ, env, clear=True), mock.patch(
      "google.auth.default", autospec=True
  ) as mock_default_auth:
    # Any attempt at default auth would raise, proving the environment
    # variable alone supplied the project.
    mock_default_auth.side_effect = DefaultCredentialsError(
        "Your default credentials were not found"
    )

    client = get_spanner_client(
        project=None,
        credentials=mock.create_autospec(Credentials, instance=True),
    )

    # Reaching this point means default auth was never invoked; assert it
    # explicitly and confirm the project came from the environment.
    mock_default_auth.assert_not_called()
    assert client.project == "test-gcp-project"
def test_spanner_client_user_agent():
  """The created client should advertise the ADK Spanner tool user agent."""
  with mock.patch(
      "google.cloud.spanner.Client", autospec=True
  ) as mock_client_class:
    # The real spanner.Client exposes `_client_info`; mirror that on the
    # mock instance so the factory can attach its user agent to it.
    mock_client_class.return_value._client_info = mock.Mock()

    client = get_spanner_client(
        project="test-gcp-project",
        credentials=mock.create_autospec(Credentials, instance=True),
    )

    # The Spanner Client must have been constructed exactly once with the
    # requested project.
    mock_client_class.assert_called_once_with(
        project="test-gcp-project",
        credentials=mock.ANY,
    )

    # The user agent set on the (mock) client instance should identify the
    # ADK Spanner tool and the ADK version.
    assert re.search(
        r"adk-spanner-tool google-adk/([0-9A-Za-z._\-+/]+)",
        client._client_info.user_agent,
    )

Some files were not shown because too many files have changed in this diff Show More