From 1fc8d20ae88451b7ed764aa86c17c3cdfaffa1cf Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 12 Aug 2025 13:59:05 -0700 Subject: [PATCH] feat: add Spanner first-party toolset (breaking change to BigQueryTool, consolidating into generic GoogleTool) Spanner toolset support basic operations to interact with Spanner table metadata and query results. Consolidate BigQueryTool into generic GoogleTool, so that BigQueryToolset and SpannerToolset can share. PiperOrigin-RevId: 794259782 --- pyproject.toml | 1 + src/google/adk/tools/_google_credentials.py | 252 +++++++++ src/google/adk/tools/bigquery/__init__.py | 2 - .../tools/bigquery/bigquery_credentials.py | 215 +------- .../adk/tools/bigquery/bigquery_toolset.py | 12 +- .../adk/tools/bigquery/data_insights_tool.py | 6 +- src/google/adk/tools/bigquery/query_tool.py | 26 +- .../bigquery_tool.py => google_tool.py} | 39 +- src/google/adk/tools/spanner/__init__.py | 40 ++ src/google/adk/tools/spanner/client.py | 33 ++ src/google/adk/tools/spanner/metadata_tool.py | 503 ++++++++++++++++++ src/google/adk/tools/spanner/query_tool.py | 114 ++++ src/google/adk/tools/spanner/settings.py | 46 ++ .../adk/tools/spanner/spanner_credentials.py | 41 ++ .../adk/tools/spanner/spanner_toolset.py | 111 ++++ .../test_bigquery_data_insights_tool.py | 10 +- .../bigquery/test_bigquery_query_tool.py | 66 ++- .../tools/bigquery/test_bigquery_toolset.py | 17 +- tests/unittests/tools/spanner/__init__ | 13 + .../tools/spanner/test_spanner_client.py | 142 +++++ .../tools/spanner/test_spanner_credentials.py | 54 ++ .../spanner/test_spanner_tool_settings.py | 27 + .../tools/spanner/test_spanner_toolset.py | 185 +++++++ ...> test_base_google_credentials_manager.py} | 18 +- ...t_bigquery_tool.py => test_google_tool.py} | 63 ++- 25 files changed, 1716 insertions(+), 320 deletions(-) create mode 100644 src/google/adk/tools/_google_credentials.py rename src/google/adk/tools/{bigquery/bigquery_tool.py => google_tool.py} (77%) create mode 100644 
src/google/adk/tools/spanner/__init__.py create mode 100644 src/google/adk/tools/spanner/client.py create mode 100644 src/google/adk/tools/spanner/metadata_tool.py create mode 100644 src/google/adk/tools/spanner/query_tool.py create mode 100644 src/google/adk/tools/spanner/settings.py create mode 100644 src/google/adk/tools/spanner/spanner_credentials.py create mode 100644 src/google/adk/tools/spanner/spanner_toolset.py create mode 100644 tests/unittests/tools/spanner/__init__ create mode 100644 tests/unittests/tools/spanner/test_spanner_client.py create mode 100644 tests/unittests/tools/spanner/test_spanner_credentials.py create mode 100644 tests/unittests/tools/spanner/test_spanner_tool_settings.py create mode 100644 tests/unittests/tools/spanner/test_spanner_toolset.py rename tests/unittests/tools/{bigquery/test_bigquery_credentials_manager.py => test_base_google_credentials_manager.py} (96%) rename tests/unittests/tools/{bigquery/test_bigquery_tool.py => test_google_tool.py} (84%) diff --git a/pyproject.toml b/pyproject.toml index 2d1414af..d00951f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "google-api-python-client>=2.157.0", # Google API client discovery "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. 
"google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool + "google-cloud-spanner>=3.56.0", # For Spanner database "google-cloud-speech>=2.30.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service "google-genai>=1.21.1", # Google GenAI SDK diff --git a/src/google/adk/tools/_google_credentials.py b/src/google/adk/tools/_google_credentials.py new file mode 100644 index 00000000..c5e25a77 --- /dev/null +++ b/src/google/adk/tools/_google_credentials.py @@ -0,0 +1,252 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from typing import List +from typing import Optional + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +import google.auth.credentials +from google.auth.exceptions import RefreshError +from google.auth.transport.requests import Request +import google.oauth2.credentials +from pydantic import BaseModel +from pydantic import model_validator + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_credential import AuthCredentialTypes +from ..auth.auth_credential import OAuth2Auth +from ..auth.auth_tool import AuthConfig +from ..utils.feature_decorator import experimental +from .tool_context import ToolContext + + +@experimental +class BaseGoogleCredentialsConfig(BaseModel): + """Base Google Credentials Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ + + # Configure the model to allow arbitrary types like Credentials + model_config = {"arbitrary_types_allowed": True} + + credentials: Optional[google.auth.credentials.Credentials] = None + """The existing auth credentials to use. If set, this credential will be used + for every end user, end users don't need to be involved in the oauthflow. This + field is mutually exclusive with client_id, client_secret and scopes. + Don't set this field unless you are sure this credential has the permission to + access every end user's data. + + Example usage 1: When the agent is deployed in Google Cloud environment and + the service account (used as application default credentials) has access to + all the required Google Cloud resource. Setting this credential to allow user + to access the Google Cloud resource without end users going through oauth + flow. + + To get application default credential, use: `google.auth.default(...)`. 
See + more details in + https://cloud.google.com/docs/authentication/application-default-credentials. + + Example usage 2: When the agent wants to access the user's Google Cloud + resources using the service account key credentials. + + To load service account key credentials, use: + `google.auth.load_credentials_from_file(...)`. See more details in + https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + + When the deployed environment cannot provide a pre-existing credential, + consider setting below client_id, client_secret and scope for end users to go + through oauth flow, so that agent can access the user data. + """ + client_id: Optional[str] = None + """the oauth client ID to use.""" + client_secret: Optional[str] = None + """the oauth client secret to use.""" + scopes: Optional[List[str]] = None + """the scopes to use.""" + + _token_cache_key: Optional[str] = None + """The key to cache the token in the tool context.""" + + @model_validator(mode="after") + def __post_init__(self) -> BaseGoogleCredentialsConfig: + """Validate that either credentials or client ID/secret are provided.""" + if not self.credentials and (not self.client_id or not self.client_secret): + raise ValueError( + "Must provide either credentials or client_id and client_secret pair." + ) + if self.credentials and ( + self.client_id or self.client_secret or self.scopes + ): + raise ValueError( + "Cannot provide both existing credentials and" + " client_id/client_secret/scopes." + ) + + if self.credentials and isinstance( + self.credentials, google.oauth2.credentials.Credentials + ): + self.client_id = self.credentials.client_id + self.client_secret = self.credentials.client_secret + self.scopes = self.credentials.scopes + + return self + + +class GoogleCredentialsManager: + """Manages Google API credentials with automatic refresh and OAuth flow handling. 
+ + This class centralizes credential management so multiple tools can share + the same authenticated session without duplicating OAuth logic. + """ + + def __init__( + self, + credentials_config: BaseGoogleCredentialsConfig, + ): + """Initialize the credential manager. + + Args: + credentials_config: Credentials containing client id and client secrete + or default credentials + """ + self.credentials_config = credentials_config + + async def get_valid_credentials( + self, tool_context: ToolContext + ) -> Optional[google.auth.credentials.Credentials]: + """Get valid credentials, handling refresh and OAuth flow as needed. + + Args: + tool_context: The tool context for OAuth flow and state management + + Returns: + Valid Credentials object, or None if OAuth flow is needed + """ + # First, try to get credentials from the tool context + creds_json = ( + tool_context.state.get(self.credentials_config._token_cache_key, None) + if self.credentials_config._token_cache_key + else None + ) + creds = ( + google.oauth2.credentials.Credentials.from_authorized_user_info( + json.loads(creds_json), self.credentials_config.scopes + ) + if creds_json + else None + ) + + # If credentails are empty use the default credential + if not creds: + creds = self.credentials_config.credentials + + # If non-oauth credentials are provided then use them as is. 
This helps + # in flows such as service account keys + if creds and not isinstance(creds, google.oauth2.credentials.Credentials): + return creds + + # Check if we have valid credentials + if creds and creds.valid: + return creds + + # Try to refresh expired credentials + if creds and creds.expired and creds.refresh_token: + try: + creds.refresh(Request()) + if creds.valid: + # Cache the refreshed credentials if token cache key is set + if self.credentials_config._token_cache_key: + tool_context.state[self.credentials_config._token_cache_key] = ( + creds.to_json() + ) + return creds + except RefreshError: + # Refresh failed, need to re-authenticate + pass + + # Need to perform OAuth flow + return await self._perform_oauth_flow(tool_context) + + async def _perform_oauth_flow( + self, tool_context: ToolContext + ) -> Optional[google.oauth2.credentials.Credentials]: + """Perform OAuth flow to get new credentials. + + Args: + tool_context: The tool context for OAuth flow + + Returns: + New Credentials object, or None if flow is in progress + """ + + # Create OAuth configuration + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + scope: f"Access to {scope}" + for scope in self.credentials_config.scopes + }, + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=self.credentials_config.client_id, + client_secret=self.credentials_config.client_secret, + ), + ) + + # Check if OAuth response is available + auth_response = tool_context.get_auth_response( + AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential) + ) + + if auth_response: + # OAuth flow completed, create credentials + creds = google.oauth2.credentials.Credentials( + token=auth_response.oauth2.access_token, + refresh_token=auth_response.oauth2.refresh_token, + 
token_uri=auth_scheme.flows.authorizationCode.tokenUrl, + client_id=self.credentials_config.client_id, + client_secret=self.credentials_config.client_secret, + scopes=list(self.credentials_config.scopes), + ) + + # Cache the new credentials if token cache key is set + if self.credentials_config._token_cache_key: + tool_context.state[self.credentials_config._token_cache_key] = ( + creds.to_json() + ) + + return creds + else: + # Request OAuth flow + tool_context.request_credential( + AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + ) + ) + return None diff --git a/src/google/adk/tools/bigquery/__init__.py b/src/google/adk/tools/bigquery/__init__.py index 3db5a5ec..9e6b1166 100644 --- a/src/google/adk/tools/bigquery/__init__.py +++ b/src/google/adk/tools/bigquery/__init__.py @@ -28,11 +28,9 @@ definition. The rationales to have customized tool are: """ from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_tool import BigQueryTool from .bigquery_toolset import BigQueryToolset __all__ = [ - "BigQueryTool", "BigQueryToolset", "BigQueryCredentialsConfig", ] diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index d0f3abe0..00df6618 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -14,227 +14,28 @@ from __future__ import annotations -import json -from typing import List -from typing import Optional - -from fastapi.openapi.models import OAuth2 -from fastapi.openapi.models import OAuthFlowAuthorizationCode -from fastapi.openapi.models import OAuthFlows -import google.auth.credentials -from google.auth.exceptions import RefreshError -from google.auth.transport.requests import Request -import google.oauth2.credentials -from pydantic import BaseModel -from pydantic import model_validator - -from ...auth.auth_credential import AuthCredential -from ...auth.auth_credential 
import AuthCredentialTypes -from ...auth.auth_credential import OAuth2Auth -from ...auth.auth_tool import AuthConfig from ...utils.feature_decorator import experimental -from ..tool_context import ToolContext +from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @experimental -class BigQueryCredentialsConfig(BaseModel): - """Configuration for Google API tools (Experimental). +class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): + """BigQuery Credentials Configuration for Google API tools (Experimental). Please do not use this in production, as it may be deprecated later. """ - # Configure the model to allow arbitrary types like Credentials - model_config = {"arbitrary_types_allowed": True} - - credentials: Optional[google.auth.credentials.Credentials] = None - """The existing auth credentials to use. If set, this credential will be used - for every end user, end users don't need to be involved in the oauthflow. This - field is mutually exclusive with client_id, client_secret and scopes. - Don't set this field unless you are sure this credential has the permission to - access every end user's data. - - Example usage 1: When the agent is deployed in Google Cloud environment and - the service account (used as application default credentials) has access to - all the required BigQuery resource. Setting this credential to allow user to - access the BigQuery resource without end users going through oauth flow. - - To get application default credential, use: `google.auth.default(...)`. See more - details in https://cloud.google.com/docs/authentication/application-default-credentials. - - Example usage 2: When the agent wants to access the user's BigQuery resources - using the service account key credentials. - - To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`. 
- See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. - - When the deployed environment cannot provide a pre-existing credential, - consider setting below client_id, client_secret and scope for end users to go - through oauth flow, so that agent can access the user data. - """ - client_id: Optional[str] = None - """the oauth client ID to use.""" - client_secret: Optional[str] = None - """the oauth client secret to use.""" - scopes: Optional[List[str]] = None - """the scopes to use.""" - - @model_validator(mode="after") def __post_init__(self) -> BigQueryCredentialsConfig: - """Validate that either credentials or client ID/secret are provided.""" - if not self.credentials and (not self.client_id or not self.client_secret): - raise ValueError( - "Must provide either credentials or client_id and client_secret pair." - ) - if self.credentials and ( - self.client_id or self.client_secret or self.scopes - ): - raise ValueError( - "Cannot provide both existing credentials and" - " client_id/client_secret/scopes." - ) - - if self.credentials and isinstance( - self.credentials, google.oauth2.credentials.Credentials - ): - self.client_id = self.credentials.client_id - self.client_secret = self.credentials.client_secret - self.scopes = self.credentials.scopes + """Populate default scope if scopes is None.""" + super().__post_init__() if not self.scopes: self.scopes = BIGQUERY_DEFAULT_SCOPE + # Set the token cache key + self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY + return self - - -class BigQueryCredentialsManager: - """Manages Google API credentials with automatic refresh and OAuth flow handling. - - This class centralizes credential management so multiple tools can share - the same authenticated session without duplicating OAuth logic. - """ - - def __init__(self, credentials_config: BigQueryCredentialsConfig): - """Initialize the credential manager. 
- - Args: - credentials_config: Credentials containing client id and client secrete - or default credentials - """ - self.credentials_config = credentials_config - - async def get_valid_credentials( - self, tool_context: ToolContext - ) -> Optional[google.auth.credentials.Credentials]: - """Get valid credentials, handling refresh and OAuth flow as needed. - - Args: - tool_context: The tool context for OAuth flow and state management - - Returns: - Valid Credentials object, or None if OAuth flow is needed - """ - # First, try to get credentials from the tool context - creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) - creds = ( - google.oauth2.credentials.Credentials.from_authorized_user_info( - json.loads(creds_json), self.credentials_config.scopes - ) - if creds_json - else None - ) - - # If credentails are empty use the default credential - if not creds: - creds = self.credentials_config.credentials - - # If non-oauth credentials are provided then use them as is. This helps - # in flows such as service account keys - if creds and not isinstance(creds, google.oauth2.credentials.Credentials): - return creds - - # Check if we have valid credentials - if creds and creds.valid: - return creds - - # Try to refresh expired credentials - if creds and creds.expired and creds.refresh_token: - try: - creds.refresh(Request()) - if creds.valid: - # Cache the refreshed credentials - tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json() - return creds - except RefreshError: - # Refresh failed, need to re-authenticate - pass - - # Need to perform OAuth flow - return await self._perform_oauth_flow(tool_context) - - async def _perform_oauth_flow( - self, tool_context: ToolContext - ) -> Optional[google.oauth2.credentials.Credentials]: - """Perform OAuth flow to get new credentials. 
- - Args: - tool_context: The tool context for OAuth flow - required_scopes: Set of required OAuth scopes - - Returns: - New Credentials object, or None if flow is in progress - """ - - # Create OAuth configuration - auth_scheme = OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - scope: f"Access to {scope}" - for scope in self.credentials_config.scopes - }, - ) - ) - ) - - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=self.credentials_config.client_id, - client_secret=self.credentials_config.client_secret, - ), - ) - - # Check if OAuth response is available - auth_response = tool_context.get_auth_response( - AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential) - ) - - if auth_response: - # OAuth flow completed, create credentials - creds = google.oauth2.credentials.Credentials( - token=auth_response.oauth2.access_token, - refresh_token=auth_response.oauth2.refresh_token, - token_uri=auth_scheme.flows.authorizationCode.tokenUrl, - client_id=self.credentials_config.client_id, - client_secret=self.credentials_config.client_secret, - scopes=list(self.credentials_config.scopes), - ) - - # Cache the new credentials - tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json() - - return creds - else: - # Request OAuth flow - tool_context.request_credential( - AuthConfig( - auth_scheme=auth_scheme, - raw_auth_credential=auth_credential, - ) - ) - return None diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 313cf499..4b84a270 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -26,9 +26,9 @@ from . 
import query_tool from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate +from ...tools.google_tool import GoogleTool from ...utils.feature_decorator import experimental from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_tool import BigQueryTool from .config import BigQueryToolConfig @@ -45,7 +45,9 @@ class BigQueryToolset(BaseToolset): ): self.tool_filter = tool_filter self._credentials_config = credentials_config - self._tool_config = bigquery_tool_config + self._tool_settings = ( + bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig() + ) def _is_tool_selected( self, tool: BaseTool, readonly_context: ReadonlyContext @@ -67,17 +69,17 @@ class BigQueryToolset(BaseToolset): ) -> List[BaseTool]: """Get tools from the toolset.""" all_tools = [ - BigQueryTool( + GoogleTool( func=func, credentials_config=self._credentials_config, - bigquery_tool_config=self._tool_config, + tool_settings=self._tool_settings, ) for func in [ metadata_tool.get_dataset_info, metadata_tool.get_table_info, metadata_tool.list_dataset_ids, metadata_tool.list_table_ids, - query_tool.get_execute_sql(self._tool_config), + query_tool.get_execute_sql(self._tool_settings), ] ] diff --git a/src/google/adk/tools/bigquery/data_insights_tool.py b/src/google/adk/tools/bigquery/data_insights_tool.py index a2fdca08..2af2249b 100644 --- a/src/google/adk/tools/bigquery/data_insights_tool.py +++ b/src/google/adk/tools/bigquery/data_insights_tool.py @@ -30,7 +30,7 @@ def ask_data_insights( user_query_with_context: str, table_references: List[Dict[str, str]], credentials: Credentials, - config: BigQueryToolConfig, + settings: BigQueryToolConfig, ) -> Dict[str, Any]: """Answers questions about structured data in BigQuery tables using natural language. 
@@ -53,7 +53,7 @@ def ask_data_insights( table_references (List[Dict[str, str]]): A list of dictionaries, each specifying a BigQuery table to be used as context for the question. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. Returns: A dictionary with two keys: @@ -135,7 +135,7 @@ def ask_data_insights( } resp = _get_stream( - ca_url, ca_payload, headers, config.max_query_result_rows + ca_url, ca_payload, headers, settings.max_query_result_rows ) except Exception as ex: # pylint: disable=broad-except return { diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index f989a9f7..5ceebc4c 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -36,7 +36,7 @@ def execute_sql( project_id: str, query: str, credentials: Credentials, - config: BigQueryToolConfig, + settings: BigQueryToolConfig, tool_context: ToolContext, ) -> dict: """Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -46,7 +46,7 @@ def execute_sql( executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. 
Returns: @@ -89,7 +89,7 @@ def execute_sql( # BigQuery connection properties where applicable bq_connection_properties = None - if not config or config.write_mode == WriteMode.BLOCKED: + if not settings or settings.write_mode == WriteMode.BLOCKED: dry_run_query_job = bq_client.query( query, project=project_id, @@ -100,7 +100,7 @@ def execute_sql( "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", } - elif config.write_mode == WriteMode.PROTECTED: + elif settings.write_mode == WriteMode.PROTECTED: # In protected write mode, write operation only to a temporary artifact is # allowed. This artifact must have been created in a BigQuery session. In # such a scenario the session info (session id and the anonymous dataset @@ -161,7 +161,7 @@ def execute_sql( query, job_config=job_config, project=project_id, - max_results=config.max_query_result_rows, + max_results=settings.max_query_result_rows, ) rows = [] for row in row_iterator: @@ -177,8 +177,8 @@ def execute_sql( result = {"status": "SUCCESS", "rows": rows} if ( - config.max_query_result_rows is not None - and len(rows) == config.max_query_result_rows + settings.max_query_result_rows is not None + and len(rows) == settings.max_query_result_rows ): result["result_is_likely_truncated"] = True return result @@ -462,19 +462,19 @@ _execute_sql_protecetd_write_examples = """ """ -def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]: - """Get the execute_sql tool customized as per the given tool config. +def get_execute_sql(settings: BigQueryToolConfig) -> Callable[..., dict]: + """Get the execute_sql tool customized as per the given tool settings. Args: - config: BigQuery tool configuration indicating the behavior of the + settings: BigQuery tool settings indicating the behavior of the execute_sql tool. Returns: callable[..., dict]: A version of the execute_sql tool respecting the tool - config. + settings. 
""" - if not config or config.write_mode == WriteMode.BLOCKED: + if not settings or settings.write_mode == WriteMode.BLOCKED: return execute_sql # Create a new function object using the original function's code and globals. @@ -495,7 +495,7 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]: functools.update_wrapper(execute_sql_wrapper, execute_sql) # Now, set the new docstring - if config.write_mode == WriteMode.PROTECTED: + if settings.write_mode == WriteMode.PROTECTED: examples = _execute_sql_protecetd_write_examples else: examples = _execute_sql_write_examples diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/google_tool.py similarity index 77% rename from src/google/adk/tools/bigquery/bigquery_tool.py rename to src/google/adk/tools/google_tool.py index 0b231edb..9776fa0f 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/google_tool.py @@ -20,19 +20,19 @@ from typing import Callable from typing import Optional from google.auth.credentials import Credentials +from pydantic import BaseModel from typing_extensions import override -from ...utils.feature_decorator import experimental -from ..function_tool import FunctionTool -from ..tool_context import ToolContext -from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_credentials import BigQueryCredentialsManager -from .config import BigQueryToolConfig +from ..utils.feature_decorator import experimental +from ._google_credentials import BaseGoogleCredentialsConfig +from ._google_credentials import GoogleCredentialsManager +from .function_tool import FunctionTool +from .tool_context import ToolContext @experimental -class BigQueryTool(FunctionTool): - """GoogleApiTool class for tools that call Google APIs. +class GoogleTool(FunctionTool): + """GoogleTool class for tools that call Google APIs. 
This class is for developers to handcraft customized Google API tools rather than auto generate Google API tools based on API specs. @@ -46,8 +46,8 @@ class BigQueryTool(FunctionTool): self, func: Callable[..., Any], *, - credentials_config: Optional[BigQueryCredentialsConfig] = None, - bigquery_tool_config: Optional[BigQueryToolConfig] = None, + credentials_config: Optional[BaseGoogleCredentialsConfig] = None, + tool_settings: Optional[BaseModel] = None, ): """Initialize the Google API tool. @@ -56,18 +56,18 @@ class BigQueryTool(FunctionTool): 'credential" parameter credentials_config: credentials config used to call Google API. If None, then we don't hanlde the auth logic + tool_settings: Tool-specific settings. This settings should be provided + by each toolset that uses this class to create customized tools. """ super().__init__(func=func) self._ignore_params.append("credentials") - self._ignore_params.append("config") + self._ignore_params.append("settings") self._credentials_manager = ( - BigQueryCredentialsManager(credentials_config) + GoogleCredentialsManager(credentials_config) if credentials_config else None ) - self._tool_config = ( - bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig() - ) + self._tool_settings = tool_settings @override async def run_async( @@ -96,7 +96,7 @@ class BigQueryTool(FunctionTool): # Execute the tool's specific logic with valid credentials return await self._run_async_with_credential( - credentials, self._tool_config, args, tool_context + credentials, self._tool_settings, args, tool_context ) except Exception as ex: @@ -108,7 +108,7 @@ class BigQueryTool(FunctionTool): async def _run_async_with_credential( self, credentials: Credentials, - tool_config: BigQueryToolConfig, + tool_settings: BaseModel, args: dict[str, Any], tool_context: ToolContext, ) -> Any: @@ -116,6 +116,7 @@ class BigQueryTool(FunctionTool): Args: credentials: Valid Google OAuth credentials + tool_settings: Tool settings args: Arguments 
passed to the tool tool_context: Tool execution context @@ -126,6 +127,6 @@ class BigQueryTool(FunctionTool): signature = inspect.signature(self.func) if "credentials" in signature.parameters: args_to_call["credentials"] = credentials - if "config" in signature.parameters: - args_to_call["config"] = tool_config + if "settings" in signature.parameters: + args_to_call["settings"] = tool_settings return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/spanner/__init__.py b/src/google/adk/tools/spanner/__init__.py new file mode 100644 index 00000000..30686b96 --- /dev/null +++ b/src/google/adk/tools/spanner/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spanner Tools (Experimental). + +Spanner Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. A dedicated Spanner toolset to provide an easier, integrated way to interact +with Spanner database and tables for building AI Agent applications quickly. +2. We want to provide more high-level tools like Search, ML.Predict, and Graph +etc. +3. We want to provide extra access guardrails and controls in those tools. +For example, execute_sql can't arbitrarily mutate existing data. +4. 
We want to provide Spanner best practices and knowledge assistants for ad-hoc +analytics queries. +5. Use Spanner Toolset for more customization and control to interact with +Spanner database and tables. +""" + +from . import spanner_credentials +from .spanner_toolset import SpannerToolset + +SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig +__all__ = [ + "SpannerToolset", + "SpannerCredentialsConfig", +] diff --git a/src/google/adk/tools/spanner/client.py b/src/google/adk/tools/spanner/client.py new file mode 100644 index 00000000..aecba9e9 --- /dev/null +++ b/src/google/adk/tools/spanner/client.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.auth.credentials import Credentials +from google.cloud import spanner + +from ... 
USER_AGENT = f"adk-spanner-tool google-adk/{version.__version__}"


def get_spanner_client(
    *, project: str, credentials: Credentials
) -> spanner.Client:
  """Get a Spanner client tagged with the ADK tool user agent.

  Args:
    project: The Google Cloud project id the client is scoped to.
    credentials: The credentials used to authenticate requests.

  Returns:
    spanner.Client: A client whose requests identify this toolset.
  """

  spanner_client = spanner.Client(project=project, credentials=credentials)
  # NOTE: `_client_info` is a private attribute of the client library —
  # there is no public setter for the user agent today.
  spanner_client._client_info.user_agent = USER_AGENT

  return spanner_client


def list_table_names(
    project_id: str,
    instance_id: str,
    database_id: str,
    credentials: Credentials,
    named_schema: str = "",
) -> dict:
  """List tables within the database.

  Args:
    project_id (str): The Google Cloud project id.
    instance_id (str): The Spanner instance id.
    database_id (str): The Spanner database id.
    credentials (Credentials): The credentials to use for the request.
    named_schema (str): The named schema to list tables in. Default is empty
      string "" to search for tables in the default schema of the database.

  Returns:
    dict: Dictionary with a list of the Spanner table names.

  Examples:
    >>> list_table_names("my_project", "my_instance", "my_database")
    {
      "status": "SUCCESS",
      "results": [
        "table_1",
        "table_2"
      ]
    }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    # An empty schema name means the caller wants the default schema, which
    # the Spanner client addresses as "_default".
    schema = named_schema if named_schema else "_default"
    tables = [table.table_id for table in database.list_tables(schema=schema)]

    return {"status": "SUCCESS", "results": tables}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def get_table_schema(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_name: str,
    credentials: Credentials,
    named_schema: str = "",
) -> dict:
  """Get schema information about a Spanner table.

  Args:
    project_id (str): The Google Cloud project id.
    instance_id (str): The Spanner instance id.
    database_id (str): The Spanner database id.
    table_name (str): The Spanner table name.
    credentials (Credentials): The credentials to use for the request.
    named_schema (str): The named schema to list tables in. Default is empty
      string "" to search for tables in the default schema of the database.

  Returns:
    dict: Dictionary with the Spanner table schema information.

  Examples:
    >>> get_table_schema("my_project", "my_instance", "my_database",
    ... "my_table")
    {
      "status": "SUCCESS",
      "results":
      {
        'colA': {
          'SPANNER_TYPE': 'STRING(1024)',
          'TABLE_SCHEMA': '',
          'ORDINAL_POSITION': 1,
          'COLUMN_DEFAULT': None,
          'IS_NULLABLE': 'NO',
          'IS_GENERATED': 'NEVER',
          'GENERATION_EXPRESSION': None,
          'IS_STORED': None,
          'KEY_COLUMN_USAGE': { # This part is added if it's a key column
            'CONSTRAINT_NAME': 'PK_Table1',
            'ORDINAL_POSITION': 1,
            'POSITION_IN_UNIQUE_CONSTRAINT': None
          }
        },
        'colB': { ... },
        ...
      }
    }
  """

  # Both lookups use query parameters to avoid SQL injection through the
  # table or schema name.
  columns_query = """
    SELECT
      COLUMN_NAME,
      TABLE_SCHEMA,
      SPANNER_TYPE,
      ORDINAL_POSITION,
      COLUMN_DEFAULT,
      IS_NULLABLE,
      IS_GENERATED,
      GENERATION_EXPRESSION,
      IS_STORED
    FROM
      INFORMATION_SCHEMA.COLUMNS
    WHERE
      TABLE_NAME = @table_name
      AND TABLE_SCHEMA = @named_schema
    ORDER BY
      ORDINAL_POSITION
  """

  key_column_usage_query = """
    SELECT
      COLUMN_NAME,
      CONSTRAINT_NAME,
      ORDINAL_POSITION,
      POSITION_IN_UNIQUE_CONSTRAINT
    FROM
      INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE
      TABLE_NAME = @table_name
      AND TABLE_SCHEMA = @named_schema
  """
  params = {"table_name": table_name, "named_schema": named_schema}
  param_types = {
      "table_name": spanner_param_types.STRING,
      "named_schema": spanner_param_types.STRING,
  }

  schema = {}
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported",
      }

    # multi_use=True lets the same snapshot serve both INFORMATION_SCHEMA
    # queries, so column and key data come from a consistent read.
    with database.snapshot(multi_use=True) as snapshot:
      result_set = snapshot.execute_sql(
          columns_query, params=params, param_types=param_types
      )
      for row in result_set:
        (
            column_name,
            table_schema,
            spanner_type,
            ordinal_position,
            column_default,
            is_nullable,
            is_generated,
            generation_expression,
            is_stored,
        ) = row
        schema[column_name] = {
            "SPANNER_TYPE": spanner_type,
            "TABLE_SCHEMA": table_schema,
            "ORDINAL_POSITION": ordinal_position,
            "COLUMN_DEFAULT": column_default,
            "IS_NULLABLE": is_nullable,
            "IS_GENERATED": is_generated,
            "GENERATION_EXPRESSION": generation_expression,
            "IS_STORED": is_stored,
        }

      key_column_result_set = snapshot.execute_sql(
          key_column_usage_query, params=params, param_types=param_types
      )
      for row in key_column_result_set:
        (
            column_name,
            constraint_name,
            ordinal_position,
            position_in_unique_constraint,
        ) = row

        key_column_properties = {
            "CONSTRAINT_NAME": constraint_name,
            "ORDINAL_POSITION": ordinal_position,
            "POSITION_IN_UNIQUE_CONSTRAINT": position_in_unique_constraint,
        }
        # Attach key column info to the existing column schema entry
        if column_name in schema:
          schema[column_name]["KEY_COLUMN_USAGE"] = key_column_properties

    try:
      json.dumps(schema)
    except (TypeError, ValueError):
      # Fall back to the string form when the schema contains values the
      # json module cannot serialize.
      schema = str(schema)

    return {"status": "SUCCESS", "results": schema}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }


def list_table_indexes(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_id: str,
    credentials: Credentials,
) -> dict:
  """Get index information about a Spanner table.

  Args:
    project_id (str): The Google Cloud project id.
    instance_id (str): The Spanner instance id.
    database_id (str): The Spanner database id.
    table_id (str): The Spanner table id.
    credentials (Credentials): The credentials to use for the request.

  Returns:
    dict: Dictionary with a list of the Spanner table index information.

  Examples:
    >>> list_table_indexes("my_project", "my_instance", "my_database",
    ... "my_table")
    {
      "status": "SUCCESS",
      "results": [
        {
          'INDEX_NAME': 'IDX_MyTable_Column_FC70CD41F3A5FD3A',
          'TABLE_SCHEMA': '',
          'INDEX_TYPE': 'INDEX',
          'PARENT_TABLE_NAME': '',
          'IS_UNIQUE': False,
          'IS_NULL_FILTERED': False,
          'INDEX_STATE': 'READ_WRITE'
        },
        {
          'INDEX_NAME': 'PRIMARY_KEY',
          'TABLE_SCHEMA': '',
          'INDEX_TYPE': 'PRIMARY_KEY',
          'PARENT_TABLE_NAME': '',
          'IS_UNIQUE': True,
          'IS_NULL_FILTERED': False,
          'INDEX_STATE': None
        }
      ]
    }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # Using query parameters is best practice to prevent SQL injection,
    # even if table_id is typically from a controlled source here.
    sql_query = (
        "SELECT INDEX_NAME, TABLE_SCHEMA, INDEX_TYPE,"
        " PARENT_TABLE_NAME, IS_UNIQUE, IS_NULL_FILTERED, INDEX_STATE "
        "FROM INFORMATION_SCHEMA.INDEXES "
        "WHERE TABLE_NAME = @table_id "  # Use query parameter
    )
    params = {"table_id": table_id}
    param_types = {"table_id": spanner_param_types.STRING}

    indexes = []
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(
          sql_query, params=params, param_types=param_types
      )
      for row in result_set:
        index_info = {
            "INDEX_NAME": row[0],
            "TABLE_SCHEMA": row[1],
            "INDEX_TYPE": row[2],
            "PARENT_TABLE_NAME": row[3],
            "IS_UNIQUE": row[4],
            "IS_NULL_FILTERED": row[5],
            "INDEX_STATE": row[6],
        }

        try:
          json.dumps(index_info)
        except (TypeError, ValueError):
          # Non-serializable values: degrade to the string form.
          index_info = str(index_info)

        indexes.append(index_info)

    return {"status": "SUCCESS", "results": indexes}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def list_table_index_columns(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_id: str,
    credentials: Credentials,
) -> dict:
  """Get the columns in each index of a Spanner table.

  Args:
    project_id (str): The Google Cloud project id.
    instance_id (str): The Spanner instance id.
    database_id (str): The Spanner database id.
    table_id (str): The Spanner table id.
    credentials (Credentials): The credentials to use for the request.

  Returns:
    dict: Dictionary with a list of Spanner table index column
      information.

  Examples:
    >>> list_table_index_columns("my_project", "my_instance", "my_database",
    ... "my_table")
    {
      "status": "SUCCESS",
      "results": [
        {
          'INDEX_NAME': 'PRIMARY_KEY',
          'TABLE_SCHEMA': '',
          'COLUMN_NAME': 'ColumnKey1',
          'ORDINAL_POSITION': 1,
          'IS_NULLABLE': 'NO',
          'SPANNER_TYPE': 'STRING(MAX)'
        },
        {
          'INDEX_NAME': 'PRIMARY_KEY',
          'TABLE_SCHEMA': '',
          'COLUMN_NAME': 'ColumnKey2',
          'ORDINAL_POSITION': 2,
          'IS_NULLABLE': 'NO',
          'SPANNER_TYPE': 'INT64'
        },
        {
          'INDEX_NAME': 'IDX_MyTable_Column_FC70CD41F3A5FD3A',
          'TABLE_SCHEMA': '',
          'COLUMN_NAME': 'Column',
          'ORDINAL_POSITION': 3,
          'IS_NULLABLE': 'NO',
          'SPANNER_TYPE': 'STRING(MAX)'
        }
      ]
    }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    sql_query = (
        "SELECT INDEX_NAME, TABLE_SCHEMA, COLUMN_NAME,"
        " ORDINAL_POSITION, IS_NULLABLE, SPANNER_TYPE "
        "FROM INFORMATION_SCHEMA.INDEX_COLUMNS "
        "WHERE TABLE_NAME = @table_id "  # Use query parameter
    )
    params = {"table_id": table_id}
    param_types = {"table_id": spanner_param_types.STRING}

    index_columns = []
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(
          sql_query, params=params, param_types=param_types
      )
      for row in result_set:
        index_column_info = {
            "INDEX_NAME": row[0],
            "TABLE_SCHEMA": row[1],
            "COLUMN_NAME": row[2],
            "ORDINAL_POSITION": row[3],
            "IS_NULLABLE": row[4],
            "SPANNER_TYPE": row[5],
        }

        try:
          json.dumps(index_column_info)
        except (TypeError, ValueError):
          # Non-serializable values: degrade to the string form.
          index_column_info = str(index_column_info)

        index_columns.append(index_column_info)

    return {"status": "SUCCESS", "results": index_columns}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }


def list_named_schemas(
    project_id: str,
    instance_id: str,
    database_id: str,
    credentials: Credentials,
) -> dict:
  """Get the named schemas in the Spanner database.

  Args:
    project_id (str): The Google Cloud project id.
    instance_id (str): The Spanner instance id.
    database_id (str): The Spanner database id.
    credentials (Credentials): The credentials to use for the request.

  Returns:
    dict: Dictionary with a list of named schemas information in the Spanner
      database.

  Examples:
    >>> list_named_schemas("my_project", "my_instance", "my_database")
    {
      "status": "SUCCESS",
      "results": [
        "schema_1",
        "schema_2"
      ]
    }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # Exclude the system schemas and the unnamed default schema ('').
    sql_query = """
      SELECT
        SCHEMA_NAME
      FROM
        INFORMATION_SCHEMA.SCHEMATA
      WHERE
        SCHEMA_NAME NOT IN ('', 'INFORMATION_SCHEMA', 'SPANNER_SYS');
    """

    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(sql_query)
      named_schemas = [row[0] for row in result_set]

    return {"status": "SUCCESS", "results": named_schemas}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50


def execute_sql(
    project_id: str,
    instance_id: str,
    database_id: str,
    query: str,
    credentials: Credentials,
    settings: SpannerToolSettings,
    tool_context: ToolContext,
) -> dict:
  """Run a Spanner Read-Only query in the spanner database and return the result.

  Args:
    project_id (str): The GCP project id in which the spanner database
      resides.
    instance_id (str): The instance id of the spanner database.
    database_id (str): The database id of the spanner database.
    query (str): The Spanner SQL query to be executed.
    credentials (Credentials): The credentials to use for the request.
    settings (SpannerToolSettings): The settings for the tool.
    tool_context (ToolContext): The context for the tool.

  Returns:
    dict: Dictionary with the result of the query.
      If the result contains the key "result_is_likely_truncated" with
      value True, it means that there may be additional rows matching the
      query not returned in the result.

  Examples:
    Fetch data or insights from a table:

    >>> execute_sql("my_project", "my_instance", "my_database",
    ... "SELECT COUNT(*) AS count FROM my_table")
    {
      "status": "SUCCESS",
      "rows": [
        [100]
      ]
    }

  Note:
    This is running with Read-Only Transaction for query that only read data.
  """

  try:
    # Get Spanner client
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # A snapshot gives a read-only transaction, so the query cannot mutate
    # data even if the model produces a non-SELECT statement.
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(query)
      rows = []
      # Cap the number of returned rows; fall back to the module default
      # when settings are absent or hold a non-positive limit.
      counter = (
          settings.max_executed_query_result_rows
          if settings and settings.max_executed_query_result_rows > 0
          else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
      )
      for row in result_set:
        try:
          # if the json serialization of the row succeeds, use it as is
          json.dumps(row)
        except (TypeError, ValueError):
          # Otherwise degrade to the string form so the result stays
          # serializable for the model.
          row = str(row)

        rows.append(row)
        counter -= 1
        if counter <= 0:
          break

      result = {"status": "SUCCESS", "rows": rows}
      if counter <= 0:
        result["result_is_likely_truncated"] = True
      return result
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
from __future__ import annotations

from enum import Enum
from typing import List

from pydantic import BaseModel

from ...utils.feature_decorator import experimental


class Capabilities(Enum):
  """Capabilities indicating what type of operation tools are allowed to be performed on Spanner."""

  DATA_READ = 'data_read'
  """Read only data operations tools are allowed."""


@experimental('Tool settings defaults may have breaking change in the future.')
class SpannerToolSettings(BaseModel):
  """Settings for Spanner tools.

  Consumed by the toolset to decide which tools to expose (see
  capabilities) and by individual tools to bound their output.
  """

  # NOTE: a mutable list default is safe here because pydantic deep-copies
  # field defaults per model instance (unlike a plain class attribute).
  capabilities: List[Capabilities] = [
      Capabilities.DATA_READ,
  ]
  """Allowed capabilities for Spanner tools.

  By default, the tool will allow only read operations. This behaviour may
  change in future versions.
  """

  max_executed_query_result_rows: int = 50
  """Maximum number of rows to return from a query result."""
from __future__ import annotations

from ...utils.feature_decorator import experimental
from .._google_credentials import BaseGoogleCredentialsConfig

# Session-state key under which the cached Spanner token is stored.
SPANNER_TOKEN_CACHE_KEY = "spanner_token_cache"
# OAuth scope granting access to Spanner data operations.
SPANNER_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/spanner.data"]


@experimental
class SpannerCredentialsConfig(BaseGoogleCredentialsConfig):
  """Spanner Credentials Configuration for Google API tools (Experimental).

  Please do not use this in production, as it may be deprecated later.
  """

  def __post_init__(self) -> SpannerCredentialsConfig:
    """Populate default scope if scopes is None."""
    # Run the base class validation/initialization first; this config only
    # layers Spanner-specific defaults on top of it.
    super().__post_init__()

    if not self.scopes:
      self.scopes = SPANNER_DEFAULT_SCOPE

    # Set the token cache key
    self._token_cache_key = SPANNER_TOKEN_CACHE_KEY

    return self
@experimental
class SpannerToolset(BaseToolset):
  """Spanner Toolset contains tools for interacting with Spanner data, database and table information."""

  def __init__(
      self,
      *,
      tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
      credentials_config: Optional[SpannerCredentialsConfig] = None,
      spanner_tool_settings: Optional[SpannerToolSettings] = None,
  ):
    self.tool_filter = tool_filter
    self._credentials_config = credentials_config
    # Fall back to default settings when the caller supplies none.
    if spanner_tool_settings:
      self._tool_settings = spanner_tool_settings
    else:
      self._tool_settings = SpannerToolSettings()

  def _is_tool_selected(
      self, tool: BaseTool, readonly_context: ReadonlyContext
  ) -> bool:
    """Return True when the tool passes the configured filter."""
    tool_filter = self.tool_filter
    # No filter configured: every tool is exposed.
    if tool_filter is None:
      return True
    # A predicate filter decides per tool and context.
    if isinstance(tool_filter, ToolPredicate):
      return tool_filter(tool, readonly_context)
    # A list filter selects tools by name.
    if isinstance(tool_filter, list):
      return tool.name in tool_filter
    return False

  @override
  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> List[BaseTool]:
    """Get tools from the toolset."""
    # Metadata tools are always offered; query tools depend on the
    # configured capabilities.
    tool_funcs = [
        metadata_tool.list_table_names,
        metadata_tool.list_table_indexes,
        metadata_tool.list_table_index_columns,
        metadata_tool.list_named_schemas,
        metadata_tool.get_table_schema,
    ]

    if (
        self._tool_settings
        and Capabilities.DATA_READ in self._tool_settings.capabilities
    ):
      tool_funcs.append(query_tool.execute_sql)

    wrapped_tools = [
        GoogleTool(
            func=tool_func,
            credentials_config=self._credentials_config,
            tool_settings=self._tool_settings,
        )
        for tool_func in tool_funcs
    ]

    return [
        tool
        for tool in wrapped_tools
        if self._is_tool_selected(tool, readonly_context)
    ]

  @override
  async def close(self):
    pass
Assert that the error was caught and formatted correctly diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index f0e673da..fe76a309 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -37,7 +37,7 @@ import pytest async def get_tool( - name: str, tool_config: Optional[BigQueryToolConfig] = None + name: str, tool_settings: Optional[BigQueryToolConfig] = None ) -> BaseTool: """Get a tool from BigQuery toolset. @@ -54,7 +54,7 @@ async def get_tool( toolset = BigQueryToolset( credentials_config=credentials_config, tool_filter=[name], - bigquery_tool_config=tool_config, + bigquery_tool_config=tool_settings, ) tools = await toolset.get_tools() @@ -64,7 +64,7 @@ async def get_tool( @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param(None, id="no-config"), pytest.param(BigQueryToolConfig(), id="default-config"), @@ -75,14 +75,14 @@ async def get_tool( ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_read_only(tool_config): +async def test_execute_sql_declaration_read_only(tool_settings): """Test BigQuery execute_sql tool declaration in read-only mode. This test verifies that the execute_sql tool declaration reflects the read-only capability. """ tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -92,7 +92,7 @@ async def test_execute_sql_declaration_read_only(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. 
tool_context (ToolContext): The context for the tool. Returns: @@ -127,7 +127,7 @@ async def test_execute_sql_declaration_read_only(tool_config): @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param( BigQueryToolConfig(write_mode=WriteMode.ALLOWED), @@ -136,14 +136,14 @@ async def test_execute_sql_declaration_read_only(tool_config): ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_write(tool_config): +async def test_execute_sql_declaration_write(tool_settings): """Test BigQuery execute_sql tool declaration with all writes enabled. This test verifies that the execute_sql tool declaration reflects the write capability. """ tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -153,7 +153,7 @@ async def test_execute_sql_declaration_write(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. Returns: @@ -326,7 +326,7 @@ async def test_execute_sql_declaration_write(tool_config): @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param( BigQueryToolConfig(write_mode=WriteMode.PROTECTED), @@ -335,14 +335,14 @@ async def test_execute_sql_declaration_write(tool_config): ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_protected_write(tool_config): +async def test_execute_sql_declaration_protected_write(tool_settings): """Test BigQuery execute_sql tool declaration with protected writes enabled. This test verifies that the execute_sql tool declaration reflects the protected write capability. 
""" tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -352,7 +352,7 @@ async def test_execute_sql_declaration_protected_write(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. Returns: @@ -530,7 +530,7 @@ def test_execute_sql_select_stmt(write_mode): statement_type = "SELECT" query_result = [{"num": 123}] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=write_mode) + tool_settings = BigQueryToolConfig(write_mode=write_mode) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -550,7 +550,9 @@ def test_execute_sql_select_stmt(write_mode): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -586,7 +588,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) tool_context = mock.create_autospec(ToolContext, instance=True) with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: @@ -602,7 +604,9 @@ def 
test_execute_sql_non_select_stmt_write_allowed(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -638,7 +642,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) tool_context = mock.create_autospec(ToolContext, instance=True) with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: @@ -654,7 +658,9 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == { "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", @@ -693,7 +699,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -714,7 +720,9 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = 
execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -756,7 +764,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -777,7 +785,9 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == { "status": "ERROR", "error_details": ( @@ -808,7 +818,7 @@ def test_execute_sql_no_default_auth( statement_type = "SELECT" query_result = [{"num": 123}] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=write_mode) + tool_settings = BigQueryToolConfig(write_mode=write_mode) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -830,7 +840,7 @@ def test_execute_sql_no_default_auth( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql(project, query, credentials, tool_settings, tool_context) assert result == {"status": "SUCCESS", "rows": query_result} mock_default_auth.assert_not_called() @@ -959,7 +969,7 @@ def test_execute_sql_result_dtype( project = "my_project" statement_type = "SELECT" credentials = mock.create_autospec(Credentials, instance=True) - tool_config 
= BigQueryToolConfig() + tool_settings = BigQueryToolConfig() tool_context = mock.create_autospec(ToolContext, instance=True) # Simulate the result of query API @@ -971,5 +981,5 @@ def test_execute_sql_result_dtype( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql(project, query, credentials, tool_settings, tool_context) assert result == {"status": "SUCCESS", "rows": tool_result_rows} diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index 4129dc51..8f21e1be 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -15,8 +15,9 @@ from __future__ import annotations from google.adk.tools.bigquery import BigQueryCredentialsConfig -from google.adk.tools.bigquery import BigQueryTool from google.adk.tools.bigquery import BigQueryToolset +from google.adk.tools.bigquery.config import BigQueryToolConfig +from google.adk.tools.google_tool import GoogleTool import pytest @@ -30,12 +31,18 @@ async def test_bigquery_toolset_tools_default(): credentials_config = BigQueryCredentialsConfig( client_id="abc", client_secret="def" ) - toolset = BigQueryToolset(credentials_config=credentials_config) + toolset = BigQueryToolset( + credentials_config=credentials_config, bigquery_tool_config=None + ) + # Verify that the tool config is initialized to default values. 
+ assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access + + tools = await toolset.get_tools() assert tools is not None assert len(tools) == 5 - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ "list_dataset_ids", @@ -77,7 +84,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): assert tools is not None assert len(tools) == len(selected_tools) - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set(selected_tools) actual_tool_names = set([tool.name for tool in tools]) @@ -114,7 +121,7 @@ async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): assert tools is not None assert len(tools) == len(returned_tools) - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set(returned_tools) actual_tool_names = set([tool.name for tool in tools]) diff --git a/tests/unittests/tools/spanner/__init__.py b/tests/unittests/tools/spanner/__init__.py new file mode 100644 index 00000000..60cac4f4 --- /dev/null +++ b/tests/unittests/tools/spanner/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/tests/unittests/tools/spanner/test_spanner_client.py b/tests/unittests/tools/spanner/test_spanner_client.py new file mode 100644 index 00000000..0aaf6967 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_client.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import re +from unittest import mock + +from google.adk.tools.spanner.client import get_spanner_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_spanner_client_project(): + """Test spanner client project.""" + # Trigger the spanner client creation + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_explicit(): + """Test spanner client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate 
exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the spanner client creation + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_with_default_auth(): + """Test spanner client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the spanner client creation + client = get_spanner_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_with_env(): + """Test spanner client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + 
mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the spanner client creation + client = get_spanner_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_spanner_client_user_agent(): + """Test spanner client user agent.""" + # Patch the Client constructor + with mock.patch( + "google.cloud.spanner.Client", autospec=True + ) as mock_client_class: + # The mock instance that will be returned by spanner.Client() + mock_instance = mock_client_class.return_value + # The real spanner.Client instance has a `_client_info` attribute. + # We need to add it to our mock instance so that the user_agent can be set. + mock_instance._client_info = mock.Mock() + + # Call the function that creates the client + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the Spanner Client was instantiated. + mock_client_class.assert_called_once_with( + project="test-gcp-project", + credentials=mock.ANY, + ) + + # Verify that the user_agent was set on the client instance. + # The client returned by get_spanner_client is the mock instance. 
+ assert re.search( + r"adk-spanner-tool google-adk/([0-9A-Za-z._\-+/]+)", + client._client_info.user_agent, + ) diff --git a/tests/unittests/tools/spanner/test_spanner_credentials.py b/tests/unittests/tools/spanner/test_spanner_credentials.py new file mode 100644 index 00000000..19430e14 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_credentials.py @@ -0,0 +1,54 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +# Mock the Google OAuth and API dependencies +import google.auth.credentials +import google.oauth2.credentials +import pytest + + +class TestSpannerCredentials: + """Test suite for Spanner credentials configuration validation. + + This class tests the credential configuration logic that ensures + either existing credentials or client ID/secret pairs are provided. + """ + + def test_valid_credentials_object_oauth2_credentials(self): + """Test that providing valid Credentials object works correctly with google.oauth2.credentials.Credentials. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. 
+ """ + # Create a mock oauth2 credentials object + oauth2_creds = google.oauth2.credentials.Credentials( + "test_token", + client_id="test_client_id", + client_secret="test_client_secret", + scopes=[], + ) + + config = SpannerCredentialsConfig(credentials=oauth2_creds) + + # Verify that the credentials are properly stored and attributes are + # extracted + assert config.credentials == oauth2_creds + assert config.client_id == "test_client_id" + assert config.client_secret == "test_client_secret" + assert config.scopes == [ + "https://www.googleapis.com/auth/spanner.data", + ] + + assert config._token_cache_key == "spanner_token_cache" # pylint: disable=protected-access diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py new file mode 100644 index 00000000..f74922b2 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -0,0 +1,27 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.tools.spanner.settings import SpannerToolSettings +import pytest + + +def test_spanner_tool_settings_experimental_warning(): + """Test SpannerToolSettings experimental warning.""" + with pytest.warns( + UserWarning, + match="Tool settings defaults may have breaking change in the future.", + ): + SpannerToolSettings() diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py new file mode 100644 index 00000000..73a780f8 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_toolset.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner import SpannerCredentialsConfig +from google.adk.tools.spanner import SpannerToolset +from google.adk.tools.spanner.settings import SpannerToolSettings +import pytest + + +@pytest.mark.asyncio +async def test_spanner_toolset_tools_default(): + """Test default Spanner toolset. + + This test verifies the behavior of the Spanner toolset when no filter is + specified. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = SpannerToolset(credentials_config=credentials_config) + assert isinstance(toolset._tool_settings, SpannerToolSettings) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == SpannerToolSettings().__dict__ # pylint: disable=protected-access + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 6 + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set([ + "list_table_names", + "list_table_indexes", + "list_table_index_columns", + "list_named_schemas", + "get_table_schema", + "execute_sql", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param( + ["list_table_names", "get_table_schema"], + id="table-metadata", + ), + pytest.param(["execute_sql"], id="query"), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_selective(selected_tools): + """Test selective Spanner toolset. + + This test verifies the behavior of the Spanner toolset when a filter is + specified. + + Args: + selected_tools: A list of tool names to filter. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=SpannerToolSettings(), + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(selected_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "execute_sql"], + ["execute_sql"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_unknown_tool(selected_tools, returned_tools): + """Test Spanner toolset with unknown tools. + + This test verifies the behavior of the Spanner toolset when unknown tools are + specified in the filter. + + Args: + selected_tools: A list of tool names to filter, including unknown ones. + returned_tools: A list of tool names that are expected to be returned. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=SpannerToolSettings(), + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param( + ["execute_sql", "list_table_names"], + ["list_table_names"], + id="read-not-added", + ), + pytest.param( + ["list_table_names", "list_table_indexes"], + ["list_table_names", "list_table_indexes"], + id="no-effect", + ), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_without_read_capability( + selected_tools, returned_tools +): + """Test Spanner toolset without read capability. + + This test verifies the behavior of the Spanner toolset when read capability is + not enabled. + + Args: + selected_tools: A list of tool names to filter. + returned_tools: A list of tool names that are expected to be returned. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + spanner_tool_settings = SpannerToolSettings(capabilities=[]) + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=spanner_tool_settings, + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/test_base_google_credentials_manager.py similarity index 96% rename from tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py rename to tests/unittests/tools/test_base_google_credentials_manager.py index 73ffa3bd..de568543 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/test_base_google_credentials_manager.py @@ -18,9 +18,9 @@ from unittest.mock import Mock from unittest.mock import patch from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig -from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials as AuthCredentials from google.auth.exceptions import RefreshError @@ -29,8 +29,8 @@ from google.oauth2.credentials import Credentials as OAuthCredentials import pytest -class TestBigQueryCredentialsManager: - """Test suite for BigQueryCredentialsManager OAuth flow handling. 
+class TestGoogleCredentialsManager: + """Test suite for GoogleCredentialsManager OAuth flow handling. This class tests the complex credential management logic including credential validation, refresh, OAuth flow orchestration, and the @@ -63,7 +63,7 @@ class TestBigQueryCredentialsManager: @pytest.fixture def manager(self, credentials_config): """Create a credentials manager instance for testing.""" - return BigQueryCredentialsManager(credentials_config) + return GoogleCredentialsManager(credentials_config) @pytest.mark.parametrize( ("credentials_class",), @@ -336,7 +336,7 @@ class TestBigQueryCredentialsManager: # Use the full module path as it appears in the project structure with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: result = await manager.get_valid_credentials(mock_tool_context) @@ -388,7 +388,7 @@ class TestBigQueryCredentialsManager: credential manager, avoiding redundant OAuth flows. """ # Create first manager instance and simulate OAuth completion - manager1 = BigQueryCredentialsManager(credentials_config) + manager1 = GoogleCredentialsManager(credentials_config) # Mock OAuth response for first manager mock_auth_response = Mock() @@ -412,7 +412,7 @@ class TestBigQueryCredentialsManager: # Use the correct module path - without the 'src.' 
prefix with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: # Complete OAuth flow with first manager @@ -424,7 +424,7 @@ class TestBigQueryCredentialsManager: assert cached_creds_json == mock_creds_json # Create second manager instance (simulating new request/session) - manager2 = BigQueryCredentialsManager(credentials_config) + manager2 = GoogleCredentialsManager(credentials_config) credentials_config.credentials = None # Reset auth response to None (no new OAuth flow available) @@ -432,7 +432,7 @@ class TestBigQueryCredentialsManager: # Mock the from_authorized_user_info method for the second manager with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = True diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool.py b/tests/unittests/tools/test_google_tool.py similarity index 84% rename from tests/unittests/tools/bigquery/test_bigquery_tool.py rename to tests/unittests/tools/test_google_tool.py index 5b1441d4..fb9da070 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool.py +++ b/tests/unittests/tools/test_google_tool.py @@ -16,18 +16,19 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig -from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager -from google.adk.tools.bigquery.bigquery_tool import BigQueryTool from google.adk.tools.bigquery.config import BigQueryToolConfig +from 
google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.tool_context import ToolContext # Mock the Google OAuth and API dependencies from google.oauth2.credentials import Credentials import pytest -class TestBigQueryTool: - """Test suite for BigQueryTool OAuth integration and execution. +class TestGoogleTool: + """Test suite for GoogleTool OAuth integration and execution. This class tests the high-level tool execution logic that combines credential management with actual function execution. @@ -88,18 +89,18 @@ class TestBigQueryTool: def test_tool_initialization_with_credentials( self, sample_function, credentials_config ): - """Test that BigQueryTool initializes correctly with credentials. + """Test that GoogleTool initializes correctly with credentials. The tool should properly inherit from FunctionTool while adding Google API specific credential management capabilities. """ - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) assert tool.func == sample_function assert tool._credentials_manager is not None - assert isinstance(tool._credentials_manager, BigQueryCredentialsManager) + assert isinstance(tool._credentials_manager, GoogleCredentialsManager) # Verify that 'credentials' parameter is ignored in function signature analysis assert "credentials" in tool._ignore_params @@ -109,7 +110,7 @@ class TestBigQueryTool: Some tools might handle authentication externally or use service accounts, so credential management should be optional. """ - tool = BigQueryTool(func=sample_function, credentials_config=None) + tool = GoogleTool(func=sample_function, credentials_config=None) assert tool.func == sample_function assert tool._credentials_manager is None @@ -123,7 +124,7 @@ class TestBigQueryTool: This tests the main happy path where credentials are available and the underlying function executes successfully. 
""" - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) @@ -152,7 +153,7 @@ class TestBigQueryTool: When credentials aren't available and OAuth flow is needed, the tool should return a user-friendly message rather than failing. """ - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) @@ -178,7 +179,7 @@ class TestBigQueryTool: Tools without credential managers should execute normally, passing None for credentials if the function accepts them. """ - tool = BigQueryTool(func=sample_function, credentials_config=None) + tool = GoogleTool(func=sample_function, credentials_config=None) result = await tool.run_async( args={"param1": "test_value"}, tool_context=mock_tool_context @@ -196,7 +197,7 @@ class TestBigQueryTool: The tool should correctly detect and execute async functions, which is important for tools that make async API calls. """ - tool = BigQueryTool( + tool = GoogleTool( func=async_sample_function, credentials_config=credentials_config ) @@ -227,7 +228,7 @@ class TestBigQueryTool: def failing_function(param1: str, credentials: Credentials = None) -> dict: raise ValueError("Something went wrong") - tool = BigQueryTool( + tool = GoogleTool( func=failing_function, credentials_config=credentials_config ) @@ -259,7 +260,7 @@ class TestBigQueryTool: ) -> dict: return {"success": True} - tool = BigQueryTool( + tool = GoogleTool( func=complex_function, credentials_config=credentials_config ) @@ -270,7 +271,7 @@ class TestBigQueryTool: assert "optional_param" not in mandatory_args @pytest.mark.parametrize( - "input_config, expected_config", + "input_settings, expected_settings", [ pytest.param( BigQueryToolConfig( @@ -281,22 +282,36 @@ class TestBigQueryTool: ), id="with_provided_config", ), - pytest.param( - None, - BigQueryToolConfig(), - id="with_none_config_creates_default", - ), ], ) - def test_tool_config_initialization(self, input_config, expected_config): 
- """Tests that self._tool_config is correctly initialized by comparing its + def test_tool_bigquery_config_initialization( + self, input_settings, expected_settings + ): + """Tests that self._tool_settings is correctly initialized by comparing its final state to an expected configuration object. """ # 1. Initialize the tool with the parameterized config - tool = BigQueryTool(func=None, bigquery_tool_config=input_config) + tool = GoogleTool(func=None, tool_settings=input_settings) # 2. Assert that the tool's config has the same attribute values # as the expected config. Comparing the __dict__ is a robust # way to check for value equality. - assert tool._tool_config.__dict__ == expected_config.__dict__ # pylint: disable=protected-access + assert tool._tool_settings.__dict__ == expected_settings.__dict__ # pylint: disable=protected-access + + @pytest.mark.parametrize( + "input_settings, expected_settings", + [ + pytest.param( + SpannerToolSettings(max_executed_query_result_rows=10), + SpannerToolSettings(max_executed_query_result_rows=10), + id="with_provided_settings", + ), + ], + ) + def test_tool_spanner_settings_initialization( + self, input_settings, expected_settings + ): + """Tests that self._tool_settings is correctly initialized with SpannerToolSettings by comparing its final state to an expected configuration object.""" + tool = GoogleTool(func=None, tool_settings=input_settings) + assert tool._tool_settings.__dict__ == expected_settings.__dict__ # pylint: disable=protected-access