diff --git a/pyproject.toml b/pyproject.toml index 2d1414af..d00951f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "google-api-python-client>=2.157.0", # Google API client discovery "google-cloud-aiplatform[agent_engines]>=1.95.1", # For VertexAI integrations, e.g. example store. "google-cloud-secret-manager>=2.22.0", # Fetching secrets in RestAPI Tool + "google-cloud-spanner>=3.56.0", # For Spanner database "google-cloud-speech>=2.30.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service "google-genai>=1.21.1", # Google GenAI SDK diff --git a/src/google/adk/tools/_google_credentials.py b/src/google/adk/tools/_google_credentials.py new file mode 100644 index 00000000..c5e25a77 --- /dev/null +++ b/src/google/adk/tools/_google_credentials.py @@ -0,0 +1,252 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from typing import List +from typing import Optional + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +import google.auth.credentials +from google.auth.exceptions import RefreshError +from google.auth.transport.requests import Request +import google.oauth2.credentials +from pydantic import BaseModel +from pydantic import model_validator + +from ..auth.auth_credential import AuthCredential +from ..auth.auth_credential import AuthCredentialTypes +from ..auth.auth_credential import OAuth2Auth +from ..auth.auth_tool import AuthConfig +from ..utils.feature_decorator import experimental +from .tool_context import ToolContext + + +@experimental +class BaseGoogleCredentialsConfig(BaseModel): + """Base Google Credentials Configuration for Google API tools (Experimental). + + Please do not use this in production, as it may be deprecated later. + """ + + # Configure the model to allow arbitrary types like Credentials + model_config = {"arbitrary_types_allowed": True} + + credentials: Optional[google.auth.credentials.Credentials] = None + """The existing auth credentials to use. If set, this credential will be used + for every end user, end users don't need to be involved in the oauthflow. This + field is mutually exclusive with client_id, client_secret and scopes. + Don't set this field unless you are sure this credential has the permission to + access every end user's data. + + Example usage 1: When the agent is deployed in Google Cloud environment and + the service account (used as application default credentials) has access to + all the required Google Cloud resource. Setting this credential to allow user + to access the Google Cloud resource without end users going through oauth + flow. + + To get application default credential, use: `google.auth.default(...)`. 
See + more details in + https://cloud.google.com/docs/authentication/application-default-credentials. + + Example usage 2: When the agent wants to access the user's Google Cloud + resources using the service account key credentials. + + To load service account key credentials, use: + `google.auth.load_credentials_from_file(...)`. See more details in + https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. + + When the deployed environment cannot provide a pre-existing credential, + consider setting below client_id, client_secret and scope for end users to go + through oauth flow, so that agent can access the user data. + """ + client_id: Optional[str] = None + """the oauth client ID to use.""" + client_secret: Optional[str] = None + """the oauth client secret to use.""" + scopes: Optional[List[str]] = None + """the scopes to use.""" + + _token_cache_key: Optional[str] = None + """The key to cache the token in the tool context.""" + + @model_validator(mode="after") + def __post_init__(self) -> BaseGoogleCredentialsConfig: + """Validate that either credentials or client ID/secret are provided.""" + if not self.credentials and (not self.client_id or not self.client_secret): + raise ValueError( + "Must provide either credentials or client_id and client_secret pair." + ) + if self.credentials and ( + self.client_id or self.client_secret or self.scopes + ): + raise ValueError( + "Cannot provide both existing credentials and" + " client_id/client_secret/scopes." + ) + + if self.credentials and isinstance( + self.credentials, google.oauth2.credentials.Credentials + ): + self.client_id = self.credentials.client_id + self.client_secret = self.credentials.client_secret + self.scopes = self.credentials.scopes + + return self + + +class GoogleCredentialsManager: + """Manages Google API credentials with automatic refresh and OAuth flow handling. 
+ + This class centralizes credential management so multiple tools can share + the same authenticated session without duplicating OAuth logic. + """ + + def __init__( + self, + credentials_config: BaseGoogleCredentialsConfig, + ): + """Initialize the credential manager. + + Args: + credentials_config: Credentials containing client id and client secret + or default credentials + """ + self.credentials_config = credentials_config + + async def get_valid_credentials( + self, tool_context: ToolContext + ) -> Optional[google.auth.credentials.Credentials]: + """Get valid credentials, handling refresh and OAuth flow as needed. + + Args: + tool_context: The tool context for OAuth flow and state management + + Returns: + Valid Credentials object, or None if OAuth flow is needed + """ + # First, try to get credentials from the tool context + creds_json = ( + tool_context.state.get(self.credentials_config._token_cache_key, None) + if self.credentials_config._token_cache_key + else None + ) + creds = ( + google.oauth2.credentials.Credentials.from_authorized_user_info( + json.loads(creds_json), self.credentials_config.scopes + ) + if creds_json + else None + ) + + # If credentials are empty use the default credential + if not creds: + creds = self.credentials_config.credentials + + # If non-oauth credentials are provided then use them as is.
This helps + # in flows such as service account keys + if creds and not isinstance(creds, google.oauth2.credentials.Credentials): + return creds + + # Check if we have valid credentials + if creds and creds.valid: + return creds + + # Try to refresh expired credentials + if creds and creds.expired and creds.refresh_token: + try: + creds.refresh(Request()) + if creds.valid: + # Cache the refreshed credentials if token cache key is set + if self.credentials_config._token_cache_key: + tool_context.state[self.credentials_config._token_cache_key] = ( + creds.to_json() + ) + return creds + except RefreshError: + # Refresh failed, need to re-authenticate + pass + + # Need to perform OAuth flow + return await self._perform_oauth_flow(tool_context) + + async def _perform_oauth_flow( + self, tool_context: ToolContext + ) -> Optional[google.oauth2.credentials.Credentials]: + """Perform OAuth flow to get new credentials. + + Args: + tool_context: The tool context for OAuth flow + + Returns: + New Credentials object, or None if flow is in progress + """ + + # Create OAuth configuration + auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://accounts.google.com/o/oauth2/auth", + tokenUrl="https://oauth2.googleapis.com/token", + scopes={ + scope: f"Access to {scope}" + for scope in self.credentials_config.scopes + }, + ) + ) + ) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id=self.credentials_config.client_id, + client_secret=self.credentials_config.client_secret, + ), + ) + + # Check if OAuth response is available + auth_response = tool_context.get_auth_response( + AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential) + ) + + if auth_response: + # OAuth flow completed, create credentials + creds = google.oauth2.credentials.Credentials( + token=auth_response.oauth2.access_token, + refresh_token=auth_response.oauth2.refresh_token, + 
token_uri=auth_scheme.flows.authorizationCode.tokenUrl, + client_id=self.credentials_config.client_id, + client_secret=self.credentials_config.client_secret, + scopes=list(self.credentials_config.scopes), + ) + + # Cache the new credentials if token cache key is set + if self.credentials_config._token_cache_key: + tool_context.state[self.credentials_config._token_cache_key] = ( + creds.to_json() + ) + + return creds + else: + # Request OAuth flow + tool_context.request_credential( + AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + ) + ) + return None diff --git a/src/google/adk/tools/bigquery/__init__.py b/src/google/adk/tools/bigquery/__init__.py index 3db5a5ec..9e6b1166 100644 --- a/src/google/adk/tools/bigquery/__init__.py +++ b/src/google/adk/tools/bigquery/__init__.py @@ -28,11 +28,9 @@ definition. The rationales to have customized tool are: """ from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_tool import BigQueryTool from .bigquery_toolset import BigQueryToolset __all__ = [ - "BigQueryTool", "BigQueryToolset", "BigQueryCredentialsConfig", ] diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index d0f3abe0..00df6618 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -14,227 +14,28 @@ from __future__ import annotations -import json -from typing import List -from typing import Optional - -from fastapi.openapi.models import OAuth2 -from fastapi.openapi.models import OAuthFlowAuthorizationCode -from fastapi.openapi.models import OAuthFlows -import google.auth.credentials -from google.auth.exceptions import RefreshError -from google.auth.transport.requests import Request -import google.oauth2.credentials -from pydantic import BaseModel -from pydantic import model_validator - -from ...auth.auth_credential import AuthCredential -from ...auth.auth_credential 
import AuthCredentialTypes -from ...auth.auth_credential import OAuth2Auth -from ...auth.auth_tool import AuthConfig from ...utils.feature_decorator import experimental -from ..tool_context import ToolContext +from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @experimental -class BigQueryCredentialsConfig(BaseModel): - """Configuration for Google API tools (Experimental). +class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): + """BigQuery Credentials Configuration for Google API tools (Experimental). Please do not use this in production, as it may be deprecated later. """ - # Configure the model to allow arbitrary types like Credentials - model_config = {"arbitrary_types_allowed": True} - - credentials: Optional[google.auth.credentials.Credentials] = None - """The existing auth credentials to use. If set, this credential will be used - for every end user, end users don't need to be involved in the oauthflow. This - field is mutually exclusive with client_id, client_secret and scopes. - Don't set this field unless you are sure this credential has the permission to - access every end user's data. - - Example usage 1: When the agent is deployed in Google Cloud environment and - the service account (used as application default credentials) has access to - all the required BigQuery resource. Setting this credential to allow user to - access the BigQuery resource without end users going through oauth flow. - - To get application default credential, use: `google.auth.default(...)`. See more - details in https://cloud.google.com/docs/authentication/application-default-credentials. - - Example usage 2: When the agent wants to access the user's BigQuery resources - using the service account key credentials. - - To load service account key credentials, use: `google.auth.load_credentials_from_file(...)`. 
- See more details in https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys. - - When the deployed environment cannot provide a pre-existing credential, - consider setting below client_id, client_secret and scope for end users to go - through oauth flow, so that agent can access the user data. - """ - client_id: Optional[str] = None - """the oauth client ID to use.""" - client_secret: Optional[str] = None - """the oauth client secret to use.""" - scopes: Optional[List[str]] = None - """the scopes to use.""" - - @model_validator(mode="after") def __post_init__(self) -> BigQueryCredentialsConfig: - """Validate that either credentials or client ID/secret are provided.""" - if not self.credentials and (not self.client_id or not self.client_secret): - raise ValueError( - "Must provide either credentials or client_id and client_secret pair." - ) - if self.credentials and ( - self.client_id or self.client_secret or self.scopes - ): - raise ValueError( - "Cannot provide both existing credentials and" - " client_id/client_secret/scopes." - ) - - if self.credentials and isinstance( - self.credentials, google.oauth2.credentials.Credentials - ): - self.client_id = self.credentials.client_id - self.client_secret = self.credentials.client_secret - self.scopes = self.credentials.scopes + """Populate default scope if scopes is None.""" + super().__post_init__() if not self.scopes: self.scopes = BIGQUERY_DEFAULT_SCOPE + # Set the token cache key + self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY + return self - - -class BigQueryCredentialsManager: - """Manages Google API credentials with automatic refresh and OAuth flow handling. - - This class centralizes credential management so multiple tools can share - the same authenticated session without duplicating OAuth logic. - """ - - def __init__(self, credentials_config: BigQueryCredentialsConfig): - """Initialize the credential manager. 
- - Args: - credentials_config: Credentials containing client id and client secrete - or default credentials - """ - self.credentials_config = credentials_config - - async def get_valid_credentials( - self, tool_context: ToolContext - ) -> Optional[google.auth.credentials.Credentials]: - """Get valid credentials, handling refresh and OAuth flow as needed. - - Args: - tool_context: The tool context for OAuth flow and state management - - Returns: - Valid Credentials object, or None if OAuth flow is needed - """ - # First, try to get credentials from the tool context - creds_json = tool_context.state.get(BIGQUERY_TOKEN_CACHE_KEY, None) - creds = ( - google.oauth2.credentials.Credentials.from_authorized_user_info( - json.loads(creds_json), self.credentials_config.scopes - ) - if creds_json - else None - ) - - # If credentails are empty use the default credential - if not creds: - creds = self.credentials_config.credentials - - # If non-oauth credentials are provided then use them as is. This helps - # in flows such as service account keys - if creds and not isinstance(creds, google.oauth2.credentials.Credentials): - return creds - - # Check if we have valid credentials - if creds and creds.valid: - return creds - - # Try to refresh expired credentials - if creds and creds.expired and creds.refresh_token: - try: - creds.refresh(Request()) - if creds.valid: - # Cache the refreshed credentials - tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json() - return creds - except RefreshError: - # Refresh failed, need to re-authenticate - pass - - # Need to perform OAuth flow - return await self._perform_oauth_flow(tool_context) - - async def _perform_oauth_flow( - self, tool_context: ToolContext - ) -> Optional[google.oauth2.credentials.Credentials]: - """Perform OAuth flow to get new credentials. 
- - Args: - tool_context: The tool context for OAuth flow - required_scopes: Set of required OAuth scopes - - Returns: - New Credentials object, or None if flow is in progress - """ - - # Create OAuth configuration - auth_scheme = OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://accounts.google.com/o/oauth2/auth", - tokenUrl="https://oauth2.googleapis.com/token", - scopes={ - scope: f"Access to {scope}" - for scope in self.credentials_config.scopes - }, - ) - ) - ) - - auth_credential = AuthCredential( - auth_type=AuthCredentialTypes.OAUTH2, - oauth2=OAuth2Auth( - client_id=self.credentials_config.client_id, - client_secret=self.credentials_config.client_secret, - ), - ) - - # Check if OAuth response is available - auth_response = tool_context.get_auth_response( - AuthConfig(auth_scheme=auth_scheme, raw_auth_credential=auth_credential) - ) - - if auth_response: - # OAuth flow completed, create credentials - creds = google.oauth2.credentials.Credentials( - token=auth_response.oauth2.access_token, - refresh_token=auth_response.oauth2.refresh_token, - token_uri=auth_scheme.flows.authorizationCode.tokenUrl, - client_id=self.credentials_config.client_id, - client_secret=self.credentials_config.client_secret, - scopes=list(self.credentials_config.scopes), - ) - - # Cache the new credentials - tool_context.state[BIGQUERY_TOKEN_CACHE_KEY] = creds.to_json() - - return creds - else: - # Request OAuth flow - tool_context.request_credential( - AuthConfig( - auth_scheme=auth_scheme, - raw_auth_credential=auth_credential, - ) - ) - return None diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 313cf499..4b84a270 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -26,9 +26,9 @@ from . 
import query_tool from ...tools.base_tool import BaseTool from ...tools.base_toolset import BaseToolset from ...tools.base_toolset import ToolPredicate +from ...tools.google_tool import GoogleTool from ...utils.feature_decorator import experimental from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_tool import BigQueryTool from .config import BigQueryToolConfig @@ -45,7 +45,9 @@ class BigQueryToolset(BaseToolset): ): self.tool_filter = tool_filter self._credentials_config = credentials_config - self._tool_config = bigquery_tool_config + self._tool_settings = ( + bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig() + ) def _is_tool_selected( self, tool: BaseTool, readonly_context: ReadonlyContext @@ -67,17 +69,17 @@ class BigQueryToolset(BaseToolset): ) -> List[BaseTool]: """Get tools from the toolset.""" all_tools = [ - BigQueryTool( + GoogleTool( func=func, credentials_config=self._credentials_config, - bigquery_tool_config=self._tool_config, + tool_settings=self._tool_settings, ) for func in [ metadata_tool.get_dataset_info, metadata_tool.get_table_info, metadata_tool.list_dataset_ids, metadata_tool.list_table_ids, - query_tool.get_execute_sql(self._tool_config), + query_tool.get_execute_sql(self._tool_settings), ] ] diff --git a/src/google/adk/tools/bigquery/data_insights_tool.py b/src/google/adk/tools/bigquery/data_insights_tool.py index a2fdca08..2af2249b 100644 --- a/src/google/adk/tools/bigquery/data_insights_tool.py +++ b/src/google/adk/tools/bigquery/data_insights_tool.py @@ -30,7 +30,7 @@ def ask_data_insights( user_query_with_context: str, table_references: List[Dict[str, str]], credentials: Credentials, - config: BigQueryToolConfig, + settings: BigQueryToolConfig, ) -> Dict[str, Any]: """Answers questions about structured data in BigQuery tables using natural language. 
@@ -53,7 +53,7 @@ def ask_data_insights( table_references (List[Dict[str, str]]): A list of dictionaries, each specifying a BigQuery table to be used as context for the question. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. Returns: A dictionary with two keys: @@ -135,7 +135,7 @@ def ask_data_insights( } resp = _get_stream( - ca_url, ca_payload, headers, config.max_query_result_rows + ca_url, ca_payload, headers, settings.max_query_result_rows ) except Exception as ex: # pylint: disable=broad-except return { diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index f989a9f7..5ceebc4c 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -36,7 +36,7 @@ def execute_sql( project_id: str, query: str, credentials: Credentials, - config: BigQueryToolConfig, + settings: BigQueryToolConfig, tool_context: ToolContext, ) -> dict: """Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -46,7 +46,7 @@ def execute_sql( executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. 
Returns: @@ -89,7 +89,7 @@ def execute_sql( # BigQuery connection properties where applicable bq_connection_properties = None - if not config or config.write_mode == WriteMode.BLOCKED: + if not settings or settings.write_mode == WriteMode.BLOCKED: dry_run_query_job = bq_client.query( query, project=project_id, @@ -100,7 +100,7 @@ def execute_sql( "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", } - elif config.write_mode == WriteMode.PROTECTED: + elif settings.write_mode == WriteMode.PROTECTED: # In protected write mode, write operation only to a temporary artifact is # allowed. This artifact must have been created in a BigQuery session. In # such a scenario the session info (session id and the anonymous dataset @@ -161,7 +161,7 @@ def execute_sql( query, job_config=job_config, project=project_id, - max_results=config.max_query_result_rows, + max_results=settings.max_query_result_rows, ) rows = [] for row in row_iterator: @@ -177,8 +177,8 @@ def execute_sql( result = {"status": "SUCCESS", "rows": rows} if ( - config.max_query_result_rows is not None - and len(rows) == config.max_query_result_rows + settings.max_query_result_rows is not None + and len(rows) == settings.max_query_result_rows ): result["result_is_likely_truncated"] = True return result @@ -462,19 +462,19 @@ _execute_sql_protecetd_write_examples = """ """ -def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]: - """Get the execute_sql tool customized as per the given tool config. +def get_execute_sql(settings: BigQueryToolConfig) -> Callable[..., dict]: + """Get the execute_sql tool customized as per the given tool settings. Args: - config: BigQuery tool configuration indicating the behavior of the + settings: BigQuery tool settings indicating the behavior of the execute_sql tool. Returns: callable[..., dict]: A version of the execute_sql tool respecting the tool - config. + settings. 
""" - if not config or config.write_mode == WriteMode.BLOCKED: + if not settings or settings.write_mode == WriteMode.BLOCKED: return execute_sql # Create a new function object using the original function's code and globals. @@ -495,7 +495,7 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]: functools.update_wrapper(execute_sql_wrapper, execute_sql) # Now, set the new docstring - if config.write_mode == WriteMode.PROTECTED: + if settings.write_mode == WriteMode.PROTECTED: examples = _execute_sql_protecetd_write_examples else: examples = _execute_sql_write_examples diff --git a/src/google/adk/tools/bigquery/bigquery_tool.py b/src/google/adk/tools/google_tool.py similarity index 77% rename from src/google/adk/tools/bigquery/bigquery_tool.py rename to src/google/adk/tools/google_tool.py index 0b231edb..9776fa0f 100644 --- a/src/google/adk/tools/bigquery/bigquery_tool.py +++ b/src/google/adk/tools/google_tool.py @@ -20,19 +20,19 @@ from typing import Callable from typing import Optional from google.auth.credentials import Credentials +from pydantic import BaseModel from typing_extensions import override -from ...utils.feature_decorator import experimental -from ..function_tool import FunctionTool -from ..tool_context import ToolContext -from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_credentials import BigQueryCredentialsManager -from .config import BigQueryToolConfig +from ..utils.feature_decorator import experimental +from ._google_credentials import BaseGoogleCredentialsConfig +from ._google_credentials import GoogleCredentialsManager +from .function_tool import FunctionTool +from .tool_context import ToolContext @experimental -class BigQueryTool(FunctionTool): - """GoogleApiTool class for tools that call Google APIs. +class GoogleTool(FunctionTool): + """GoogleTool class for tools that call Google APIs. 
This class is for developers to handcraft customized Google API tools rather than auto generate Google API tools based on API specs. @@ -46,8 +46,8 @@ class BigQueryTool(FunctionTool): self, func: Callable[..., Any], *, - credentials_config: Optional[BigQueryCredentialsConfig] = None, - bigquery_tool_config: Optional[BigQueryToolConfig] = None, + credentials_config: Optional[BaseGoogleCredentialsConfig] = None, + tool_settings: Optional[BaseModel] = None, ): """Initialize the Google API tool. @@ -56,18 +56,18 @@ class BigQueryTool(FunctionTool): 'credential" parameter credentials_config: credentials config used to call Google API. If None, then we don't hanlde the auth logic + tool_settings: Tool-specific settings. This settings should be provided + by each toolset that uses this class to create customized tools. """ super().__init__(func=func) self._ignore_params.append("credentials") - self._ignore_params.append("config") + self._ignore_params.append("settings") self._credentials_manager = ( - BigQueryCredentialsManager(credentials_config) + GoogleCredentialsManager(credentials_config) if credentials_config else None ) - self._tool_config = ( - bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig() - ) + self._tool_settings = tool_settings @override async def run_async( @@ -96,7 +96,7 @@ class BigQueryTool(FunctionTool): # Execute the tool's specific logic with valid credentials return await self._run_async_with_credential( - credentials, self._tool_config, args, tool_context + credentials, self._tool_settings, args, tool_context ) except Exception as ex: @@ -108,7 +108,7 @@ class BigQueryTool(FunctionTool): async def _run_async_with_credential( self, credentials: Credentials, - tool_config: BigQueryToolConfig, + tool_settings: BaseModel, args: dict[str, Any], tool_context: ToolContext, ) -> Any: @@ -116,6 +116,7 @@ class BigQueryTool(FunctionTool): Args: credentials: Valid Google OAuth credentials + tool_settings: Tool settings args: Arguments 
passed to the tool tool_context: Tool execution context @@ -126,6 +127,6 @@ class BigQueryTool(FunctionTool): signature = inspect.signature(self.func) if "credentials" in signature.parameters: args_to_call["credentials"] = credentials - if "config" in signature.parameters: - args_to_call["config"] = tool_config + if "settings" in signature.parameters: + args_to_call["settings"] = tool_settings return await super().run_async(args=args_to_call, tool_context=tool_context) diff --git a/src/google/adk/tools/spanner/__init__.py b/src/google/adk/tools/spanner/__init__.py new file mode 100644 index 00000000..30686b96 --- /dev/null +++ b/src/google/adk/tools/spanner/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spanner Tools (Experimental). + +Spanner Tools under this module are hand crafted and customized while the tools +under google.adk.tools.google_api_tool are auto generated based on API +definition. The rationales to have customized tool are: + +1. A dedicated Spanner toolset to provide an easier, integrated way to interact +with Spanner database and tables for building AI Agent applications quickly. +2. We want to provide more high-level tools like Search, ML.Predict, and Graph +etc. +3. We want to provide extra access guardrails and controls in those tools. +For example, execute_sql can't arbitrarily mutate existing data. +4. 
We want to provide Spanner best practices and knowledge assistants for ad-hoc +analytics queries. +5. Use Spanner Toolset for more customization and control to interact with +Spanner database and tables. +""" + +from . import spanner_credentials +from .spanner_toolset import SpannerToolset + +SpannerCredentialsConfig = spanner_credentials.SpannerCredentialsConfig +__all__ = [ + "SpannerToolset", + "SpannerCredentialsConfig", +] diff --git a/src/google/adk/tools/spanner/client.py b/src/google/adk/tools/spanner/client.py new file mode 100644 index 00000000..aecba9e9 --- /dev/null +++ b/src/google/adk/tools/spanner/client.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.auth.credentials import Credentials +from google.cloud import spanner + +from ... 
def get_spanner_client(
    *, project: str, credentials: Credentials
) -> spanner.Client:
  """Create a Spanner client tagged with the ADK tool user agent.

  Args:
      project: The Google Cloud project id the client operates in.
      credentials: Credentials used to authenticate requests.

  Returns:
      spanner.Client: A client whose outgoing requests carry the ADK
      Spanner tool user-agent string for usage attribution.
  """
  new_client = spanner.Client(project=project, credentials=credentials)
  # NOTE: ``_client_info`` is a private attribute of the client; it is set
  # here so requests are attributed to the ADK Spanner tool.
  new_client._client_info.user_agent = USER_AGENT
  return new_client
def list_table_names(
    project_id: str,
    instance_id: str,
    database_id: str,
    credentials: Credentials,
    named_schema: str = "",
) -> dict:
  """List tables within the database.

  Args:
      project_id (str): The Google Cloud project id.
      instance_id (str): The Spanner instance id.
      database_id (str): The Spanner database id.
      credentials (Credentials): The credentials to use for the request.
      named_schema (str): The named schema to list tables in. Default is empty
        string "" to search for tables in the default schema of the database.

  Returns:
      dict: Dictionary with a list of the Spanner table names.

  Examples:
      >>> list_table_names("my_project", "my_instance", "my_database")
      {
          "status": "SUCCESS",
          "results": [
              "table_1",
              "table_2"
          ]
      }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    tables = []
    # An empty schema name selects the database's default schema, which the
    # Spanner client API addresses as "_default".
    named_schema = named_schema if named_schema else "_default"
    for table in database.list_tables(schema=named_schema):
      tables.append(table.table_id)

    return {"status": "SUCCESS", "results": tables}
  except Exception as ex:
    # All failures are surfaced as a structured error payload rather than an
    # exception, so the calling agent can inspect the result uniformly.
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def get_table_schema(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_name: str,
    credentials: Credentials,
    named_schema: str = "",
) -> dict:
  """Get schema information about a Spanner table.

  Args:
      project_id (str): The Google Cloud project id.
      instance_id (str): The Spanner instance id.
      database_id (str): The Spanner database id.
      table_name (str): The Spanner table name.
      credentials (Credentials): The credentials to use for the request.
      named_schema (str): The named schema the table belongs to. Default is
        empty string "" to search for the table in the default schema of the
        database.

  Returns:
      dict: Dictionary with the Spanner table schema information.

  Examples:
      >>> get_table_schema("my_project", "my_instance", "my_database",
      ... "my_table")
      {
          "status": "SUCCESS",
          "results":
          {
              'colA': {
                  'SPANNER_TYPE': 'STRING(1024)',
                  'TABLE_SCHEMA': '',
                  'ORDINAL_POSITION': 1,
                  'COLUMN_DEFAULT': None,
                  'IS_NULLABLE': 'NO',
                  'IS_GENERATED': 'NEVER',
                  'GENERATION_EXPRESSION': None,
                  'IS_STORED': None,
                  'KEY_COLUMN_USAGE': { # This part is added if it's a key column
                      'CONSTRAINT_NAME': 'PK_Table1',
                      'ORDINAL_POSITION': 1,
                      'POSITION_IN_UNIQUE_CONSTRAINT': None
                  }
              },
              'colB': { ... },
              ...
          }
      }
  """

  # Query parameters prevent SQL injection via table/schema names.
  columns_query = """
      SELECT
        COLUMN_NAME,
        TABLE_SCHEMA,
        SPANNER_TYPE,
        ORDINAL_POSITION,
        COLUMN_DEFAULT,
        IS_NULLABLE,
        IS_GENERATED,
        GENERATION_EXPRESSION,
        IS_STORED
      FROM
        INFORMATION_SCHEMA.COLUMNS
      WHERE
        TABLE_NAME = @table_name
        AND TABLE_SCHEMA = @named_schema
      ORDER BY
        ORDINAL_POSITION
  """

  key_column_usage_query = """
      SELECT
        COLUMN_NAME,
        CONSTRAINT_NAME,
        ORDINAL_POSITION,
        POSITION_IN_UNIQUE_CONSTRAINT
      FROM
        INFORMATION_SCHEMA.KEY_COLUMN_USAGE
      WHERE
        TABLE_NAME = @table_name
        AND TABLE_SCHEMA = @named_schema
  """
  params = {"table_name": table_name, "named_schema": named_schema}
  param_types = {
      "table_name": spanner_param_types.STRING,
      "named_schema": spanner_param_types.STRING,
  }

  schema = {}
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported",
      }

    # multi_use=True lets both queries run against one consistent snapshot.
    with database.snapshot(multi_use=True) as snapshot:
      result_set = snapshot.execute_sql(
          columns_query, params=params, param_types=param_types
      )
      for row in result_set:
        (
            column_name,
            table_schema,
            spanner_type,
            ordinal_position,
            column_default,
            is_nullable,
            is_generated,
            generation_expression,
            is_stored,
        ) = row
        column_metadata = {
            "SPANNER_TYPE": spanner_type,
            "TABLE_SCHEMA": table_schema,
            "ORDINAL_POSITION": ordinal_position,
            "COLUMN_DEFAULT": column_default,
            "IS_NULLABLE": is_nullable,
            "IS_GENERATED": is_generated,
            "GENERATION_EXPRESSION": generation_expression,
            "IS_STORED": is_stored,
        }
        schema[column_name] = column_metadata

      key_column_result_set = snapshot.execute_sql(
          key_column_usage_query, params=params, param_types=param_types
      )
      for row in key_column_result_set:
        (
            column_name,
            constraint_name,
            ordinal_position,
            position_in_unique_constraint,
        ) = row

        key_column_properties = {
            "CONSTRAINT_NAME": constraint_name,
            "ORDINAL_POSITION": ordinal_position,
            "POSITION_IN_UNIQUE_CONSTRAINT": position_in_unique_constraint,
        }
        # Attach key column info to the existing column schema entry
        if column_name in schema:
          schema[column_name]["KEY_COLUMN_USAGE"] = key_column_properties

    # Fall back to the string form if any value is not JSON-serializable.
    try:
      json.dumps(schema)
    except (TypeError, ValueError):
      schema = str(schema)

    return {"status": "SUCCESS", "results": schema}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def list_table_indexes(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_id: str,
    credentials: Credentials,
) -> dict:
  """Get index information about a Spanner table.

  Args:
      project_id (str): The Google Cloud project id.
      instance_id (str): The Spanner instance id.
      database_id (str): The Spanner database id.
      table_id (str): The Spanner table id.
      credentials (Credentials): The credentials to use for the request.

  Returns:
      dict: Dictionary with a list of the Spanner table index information.

  Examples:
      >>> list_table_indexes("my_project", "my_instance", "my_database",
      ... "my_table")
      {
          "status": "SUCCESS",
          "results": [
              {
                  'INDEX_NAME': 'IDX_MyTable_Column_FC70CD41F3A5FD3A',
                  'TABLE_SCHEMA': '',
                  'INDEX_TYPE': 'INDEX',
                  'PARENT_TABLE_NAME': '',
                  'IS_UNIQUE': False,
                  'IS_NULL_FILTERED': False,
                  'INDEX_STATE': 'READ_WRITE'
              },
              {
                  'INDEX_NAME': 'PRIMARY_KEY',
                  'TABLE_SCHEMA': '',
                  'INDEX_TYPE': 'PRIMARY_KEY',
                  'PARENT_TABLE_NAME': '',
                  'IS_UNIQUE': True,
                  'IS_NULL_FILTERED': False,
                  'INDEX_STATE': None
              }
          ]
      }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # Using query parameters is best practice to prevent SQL injection,
    # even if table_id is typically from a controlled source here.
    sql_query = (
        "SELECT INDEX_NAME, TABLE_SCHEMA, INDEX_TYPE,"
        " PARENT_TABLE_NAME, IS_UNIQUE, IS_NULL_FILTERED, INDEX_STATE "
        "FROM INFORMATION_SCHEMA.INDEXES "
        "WHERE TABLE_NAME = @table_id "  # Use query parameter
    )
    params = {"table_id": table_id}
    param_types = {"table_id": spanner_param_types.STRING}

    indexes = []
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(
          sql_query, params=params, param_types=param_types
      )
      for row in result_set:
        index_info = {}
        index_info["INDEX_NAME"] = row[0]
        index_info["TABLE_SCHEMA"] = row[1]
        index_info["INDEX_TYPE"] = row[2]
        index_info["PARENT_TABLE_NAME"] = row[3]
        index_info["IS_UNIQUE"] = row[4]
        index_info["IS_NULL_FILTERED"] = row[5]
        index_info["INDEX_STATE"] = row[6]

        # Fall back to the string form if the row is not JSON-serializable.
        try:
          json.dumps(index_info)
        except (TypeError, ValueError):
          index_info = str(index_info)

        indexes.append(index_info)

    return {"status": "SUCCESS", "results": indexes}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def list_table_index_columns(
    project_id: str,
    instance_id: str,
    database_id: str,
    table_id: str,
    credentials: Credentials,
) -> dict:
  """Get the columns in each index of a Spanner table.

  Args:
      project_id (str): The Google Cloud project id.
      instance_id (str): The Spanner instance id.
      database_id (str): The Spanner database id.
      table_id (str): The Spanner table id.
      credentials (Credentials): The credentials to use for the request.

  Returns:
      dict: Dictionary with a list of Spanner table index column
      information.

  Examples:
      >>> list_table_index_columns("my_project", "my_instance", "my_database",
      ... "my_table")
      {
          "status": "SUCCESS",
          "results": [
              {
                  'INDEX_NAME': 'PRIMARY_KEY',
                  'TABLE_SCHEMA': '',
                  'COLUMN_NAME': 'ColumnKey1',
                  'ORDINAL_POSITION': 1,
                  'IS_NULLABLE': 'NO',
                  'SPANNER_TYPE': 'STRING(MAX)'
              },
              {
                  'INDEX_NAME': 'PRIMARY_KEY',
                  'TABLE_SCHEMA': '',
                  'COLUMN_NAME': 'ColumnKey2',
                  'ORDINAL_POSITION': 2,
                  'IS_NULLABLE': 'NO',
                  'SPANNER_TYPE': 'INT64'
              },
              {
                  'INDEX_NAME': 'IDX_MyTable_Column_FC70CD41F3A5FD3A',
                  'TABLE_SCHEMA': '',
                  'COLUMN_NAME': 'Column',
                  'ORDINAL_POSITION': 3,
                  'IS_NULLABLE': 'NO',
                  'SPANNER_TYPE': 'STRING(MAX)'
              }
          ]
      }
  """
  try:
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    sql_query = (
        "SELECT INDEX_NAME, TABLE_SCHEMA, COLUMN_NAME,"
        " ORDINAL_POSITION, IS_NULLABLE, SPANNER_TYPE "
        "FROM INFORMATION_SCHEMA.INDEX_COLUMNS "
        "WHERE TABLE_NAME = @table_id "  # Use query parameter
    )
    params = {"table_id": table_id}
    param_types = {"table_id": spanner_param_types.STRING}

    index_columns = []
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(
          sql_query, params=params, param_types=param_types
      )
      for row in result_set:
        index_column_info = {}
        index_column_info["INDEX_NAME"] = row[0]
        index_column_info["TABLE_SCHEMA"] = row[1]
        index_column_info["COLUMN_NAME"] = row[2]
        index_column_info["ORDINAL_POSITION"] = row[3]
        index_column_info["IS_NULLABLE"] = row[4]
        index_column_info["SPANNER_TYPE"] = row[5]

        # Fall back to the string form if the row is not JSON-serializable.
        try:
          json.dumps(index_column_info)
        except (TypeError, ValueError):
          index_column_info = str(index_column_info)

        index_columns.append(index_column_info)

    return {"status": "SUCCESS", "results": index_columns}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def list_named_schemas(
    project_id: str,
    instance_id: str,
    database_id: str,
    credentials: Credentials,
) -> dict:
  """Get the named schemas in the Spanner database.

  Args:
      project_id (str): The Google Cloud project id.
      instance_id (str): The Spanner instance id.
      database_id (str): The Spanner database id.
      credentials (Credentials): The credentials to use for the request.

  Returns:
      dict: Dictionary with a list of named schemas information in the Spanner
      database.

  Examples:
      >>> list_named_schemas("my_project", "my_instance", "my_database")
      {
          "status": "SUCCESS",
          "results": [
              "schema_1",
              "schema_2"
          ]
      }
  """
  try:
    db_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    database = db_client.instance(instance_id).database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # Exclude the default schema ('') and Spanner's built-in system schemas.
    sql_query = """
      SELECT
        SCHEMA_NAME
      FROM
        INFORMATION_SCHEMA.SCHEMATA
      WHERE
        SCHEMA_NAME NOT IN ('', 'INFORMATION_SCHEMA', 'SPANNER_SYS');
    """

    with database.snapshot() as snapshot:
      schema_names = [row[0] for row in snapshot.execute_sql(sql_query)]

    return {"status": "SUCCESS", "results": schema_names}
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
def execute_sql(
    project_id: str,
    instance_id: str,
    database_id: str,
    query: str,
    credentials: Credentials,
    settings: SpannerToolSettings,
    tool_context: ToolContext,
) -> dict:
  """Run a Spanner Read-Only query in the spanner database and return the result.

  Args:
      project_id (str): The GCP project id in which the spanner database
        resides.
      instance_id (str): The instance id of the spanner database.
      database_id (str): The database id of the spanner database.
      query (str): The Spanner SQL query to be executed.
      credentials (Credentials): The credentials to use for the request.
      settings (SpannerToolSettings): The settings for the tool.
      tool_context (ToolContext): The context for the tool.

  Returns:
      dict: Dictionary with the result of the query.
            If the result contains the key "result_is_likely_truncated" with
            value True, it means that there may be additional rows matching the
            query not returned in the result.

  Examples:
      Fetch data or insights from a table:

          >>> execute_sql("my_project", "my_instance", "my_database",
          ... "SELECT COUNT(*) AS count FROM my_table")
          {
              "status": "SUCCESS",
              "rows": [
                  [100]
              ]
          }

  Note:
      This is running with Read-Only Transaction for query that only read data.
  """

  try:
    # Get Spanner client
    spanner_client = client.get_spanner_client(
        project=project_id, credentials=credentials
    )
    instance = spanner_client.instance(instance_id)
    database = instance.database(database_id)

    if database.database_dialect == DatabaseDialect.POSTGRESQL:
      return {
          "status": "ERROR",
          "error_details": "PostgreSQL dialect is not supported.",
      }

    # A snapshot provides a read-only transaction, so the query cannot
    # mutate data.
    with database.snapshot() as snapshot:
      result_set = snapshot.execute_sql(query)
      rows = []
      # Cap the number of returned rows; fall back to the module default when
      # settings are absent or non-positive.
      counter = (
          settings.max_executed_query_result_rows
          if settings and settings.max_executed_query_result_rows > 0
          else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
      )
      for row in result_set:
        try:
          # if the json serialization of the row succeeds, use it as is
          json.dumps(row)
        except (TypeError, ValueError):
          row = str(row)

        rows.append(row)
        counter -= 1
        if counter <= 0:
          break

    result = {"status": "SUCCESS", "rows": rows}
    # Hitting the cap means more rows may exist; flag possible truncation.
    if counter <= 0:
      result["result_is_likely_truncated"] = True
    return result
  except Exception as ex:
    return {
        "status": "ERROR",
        "error_details": str(ex),
    }
from __future__ import annotations

from enum import Enum
from typing import List

from pydantic import BaseModel

from ...utils.feature_decorator import experimental


class Capabilities(Enum):
  """Capabilities indicating what type of operation tools are allowed to be performed on Spanner."""

  DATA_READ = 'data_read'
  """Read only data operations tools are allowed."""


@experimental('Tool settings defaults may have breaking change in the future.')
class SpannerToolSettings(BaseModel):
  """Settings for Spanner tools."""

  # Pydantic deep-copies field defaults per instance, so this mutable list
  # literal is safe as a default value here (unlike a plain class attribute).
  capabilities: List[Capabilities] = [
      Capabilities.DATA_READ,
  ]
  """Allowed capabilities for Spanner tools.

  By default, the tool will allow only read operations. This behaviour may
  change in future versions.
  """

  # Rows beyond this limit are dropped by execute_sql, which then marks the
  # result as likely truncated.
  max_executed_query_result_rows: int = 50
  """Maximum number of rows to return from a query result."""
from __future__ import annotations

from ...utils.feature_decorator import experimental
from .._google_credentials import BaseGoogleCredentialsConfig

# Cache key under which the per-user Spanner OAuth token is stored.
SPANNER_TOKEN_CACHE_KEY = "spanner_token_cache"
# Default OAuth scope granting access to Spanner data.
SPANNER_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/spanner.data"]


@experimental
class SpannerCredentialsConfig(BaseGoogleCredentialsConfig):
  """Spanner Credentials Configuration for Google API tools (Experimental).

  Please do not use this in production, as it may be deprecated later.
  """

  def __post_init__(self) -> SpannerCredentialsConfig:
    """Populate default scope if scopes is None."""
    # NOTE(review): pydantic does not call __post_init__ natively; this
    # assumes BaseGoogleCredentialsConfig defines __post_init__ and invokes
    # it (e.g. from a model_validator) — confirm against the base class.
    super().__post_init__()

    if not self.scopes:
      self.scopes = SPANNER_DEFAULT_SCOPE

    # Set the token cache key
    self._token_cache_key = SPANNER_TOKEN_CACHE_KEY

    return self
@experimental
class SpannerToolset(BaseToolset):
  """Spanner Toolset contains tools for interacting with Spanner data, database and table information."""

  def __init__(
      self,
      *,
      tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
      credentials_config: Optional[SpannerCredentialsConfig] = None,
      spanner_tool_settings: Optional[SpannerToolSettings] = None,
  ):
    self.tool_filter = tool_filter
    self._credentials_config = credentials_config
    # Fall back to default settings when none are supplied.
    self._tool_settings = spanner_tool_settings or SpannerToolSettings()

  def _is_tool_selected(
      self, tool: BaseTool, readonly_context: ReadonlyContext
  ) -> bool:
    """Return True when ``tool`` passes the configured tool filter."""
    active_filter = self.tool_filter
    if active_filter is None:
      # No filter configured: every tool is selected.
      return True

    if isinstance(active_filter, ToolPredicate):
      return active_filter(tool, readonly_context)

    if isinstance(active_filter, list):
      return tool.name in active_filter

    return False

  @override
  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> List[BaseTool]:
    """Get tools from the toolset."""

    def _wrap(func) -> GoogleTool:
      # Every tool shares the toolset's credentials and settings.
      return GoogleTool(
          func=func,
          credentials_config=self._credentials_config,
          tool_settings=self._tool_settings,
      )

    tool_funcs = [
        # Metadata tools
        metadata_tool.list_table_names,
        metadata_tool.list_table_indexes,
        metadata_tool.list_table_index_columns,
        metadata_tool.list_named_schemas,
        metadata_tool.get_table_schema,
    ]

    # Query tools are only exposed when the read capability is granted.
    if (
        self._tool_settings
        and Capabilities.DATA_READ in self._tool_settings.capabilities
    ):
      tool_funcs.append(query_tool.execute_sql)

    return [
        wrapped
        for wrapped in map(_wrap, tool_funcs)
        if self._is_tool_selected(wrapped, readonly_context)
    ]

  @override
  async def close(self):
    """No long-lived resources to release."""
    pass
Assert that the error was caught and formatted correctly diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index f0e673da..fe76a309 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -37,7 +37,7 @@ import pytest async def get_tool( - name: str, tool_config: Optional[BigQueryToolConfig] = None + name: str, tool_settings: Optional[BigQueryToolConfig] = None ) -> BaseTool: """Get a tool from BigQuery toolset. @@ -54,7 +54,7 @@ async def get_tool( toolset = BigQueryToolset( credentials_config=credentials_config, tool_filter=[name], - bigquery_tool_config=tool_config, + bigquery_tool_config=tool_settings, ) tools = await toolset.get_tools() @@ -64,7 +64,7 @@ async def get_tool( @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param(None, id="no-config"), pytest.param(BigQueryToolConfig(), id="default-config"), @@ -75,14 +75,14 @@ async def get_tool( ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_read_only(tool_config): +async def test_execute_sql_declaration_read_only(tool_settings): """Test BigQuery execute_sql tool declaration in read-only mode. This test verifies that the execute_sql tool declaration reflects the read-only capability. """ tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -92,7 +92,7 @@ async def test_execute_sql_declaration_read_only(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. 
tool_context (ToolContext): The context for the tool. Returns: @@ -127,7 +127,7 @@ async def test_execute_sql_declaration_read_only(tool_config): @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param( BigQueryToolConfig(write_mode=WriteMode.ALLOWED), @@ -136,14 +136,14 @@ async def test_execute_sql_declaration_read_only(tool_config): ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_write(tool_config): +async def test_execute_sql_declaration_write(tool_settings): """Test BigQuery execute_sql tool declaration with all writes enabled. This test verifies that the execute_sql tool declaration reflects the write capability. """ tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -153,7 +153,7 @@ async def test_execute_sql_declaration_write(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. Returns: @@ -326,7 +326,7 @@ async def test_execute_sql_declaration_write(tool_config): @pytest.mark.parametrize( - ("tool_config",), + ("tool_settings",), [ pytest.param( BigQueryToolConfig(write_mode=WriteMode.PROTECTED), @@ -335,14 +335,14 @@ async def test_execute_sql_declaration_write(tool_config): ], ) @pytest.mark.asyncio -async def test_execute_sql_declaration_protected_write(tool_config): +async def test_execute_sql_declaration_protected_write(tool_settings): """Test BigQuery execute_sql tool declaration with protected writes enabled. This test verifies that the execute_sql tool declaration reflects the protected write capability. 
""" tool_name = "execute_sql" - tool = await get_tool(tool_name, tool_config) + tool = await get_tool(tool_name, tool_settings) assert tool.name == tool_name assert tool.description == textwrap.dedent("""\ Run a BigQuery or BigQuery ML SQL query in the project and return the result. @@ -352,7 +352,7 @@ async def test_execute_sql_declaration_protected_write(tool_config): executed. query (str): The BigQuery SQL query to be executed. credentials (Credentials): The credentials to use for the request. - config (BigQueryToolConfig): The configuration for the tool. + settings (BigQueryToolConfig): The settings for the tool. tool_context (ToolContext): The context for the tool. Returns: @@ -530,7 +530,7 @@ def test_execute_sql_select_stmt(write_mode): statement_type = "SELECT" query_result = [{"num": 123}] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=write_mode) + tool_settings = BigQueryToolConfig(write_mode=write_mode) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -550,7 +550,9 @@ def test_execute_sql_select_stmt(write_mode): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -586,7 +588,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) tool_context = mock.create_autospec(ToolContext, instance=True) with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: @@ -602,7 +604,9 @@ def 
test_execute_sql_non_select_stmt_write_allowed(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -638,7 +642,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) tool_context = mock.create_autospec(ToolContext, instance=True) with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: @@ -654,7 +658,9 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == { "status": "ERROR", "error_details": "Read-only mode only supports SELECT statements.", @@ -693,7 +699,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -714,7 +720,9 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = 
execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -756,7 +764,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( project = "my_project" query_result = [] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + tool_settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -777,7 +785,9 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == { "status": "ERROR", "error_details": ( @@ -808,7 +818,7 @@ def test_execute_sql_no_default_auth( statement_type = "SELECT" query_result = [{"num": 123}] credentials = mock.create_autospec(Credentials, instance=True) - tool_config = BigQueryToolConfig(write_mode=write_mode) + tool_settings = BigQueryToolConfig(write_mode=write_mode) tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = ( "test-bq-session-id", @@ -830,7 +840,7 @@ def test_execute_sql_no_default_auth( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql(project, query, credentials, tool_settings, tool_context) assert result == {"status": "SUCCESS", "rows": query_result} mock_default_auth.assert_not_called() @@ -959,7 +969,7 @@ def test_execute_sql_result_dtype( project = "my_project" statement_type = "SELECT" credentials = mock.create_autospec(Credentials, instance=True) - tool_config 
= BigQueryToolConfig() + tool_settings = BigQueryToolConfig() tool_context = mock.create_autospec(ToolContext, instance=True) # Simulate the result of query API @@ -971,5 +981,5 @@ def test_execute_sql_result_dtype( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = execute_sql(project, query, credentials, tool_settings, tool_context) assert result == {"status": "SUCCESS", "rows": tool_result_rows} diff --git a/tests/unittests/tools/bigquery/test_bigquery_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_toolset.py index 4129dc51..8f21e1be 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_toolset.py +++ b/tests/unittests/tools/bigquery/test_bigquery_toolset.py @@ -15,8 +15,9 @@ from __future__ import annotations from google.adk.tools.bigquery import BigQueryCredentialsConfig -from google.adk.tools.bigquery import BigQueryTool from google.adk.tools.bigquery import BigQueryToolset +from google.adk.tools.bigquery.config import BigQueryToolConfig +from google.adk.tools.google_tool import GoogleTool import pytest @@ -30,12 +31,18 @@ async def test_bigquery_toolset_tools_default(): credentials_config = BigQueryCredentialsConfig( client_id="abc", client_secret="def" ) - toolset = BigQueryToolset(credentials_config=credentials_config) + toolset = BigQueryToolset( + credentials_config=credentials_config, bigquery_tool_config=None + ) + # Verify that the tool config is initialized to default values. 
+ assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access + + tools = await toolset.get_tools() assert tools is not None assert len(tools) == 5 - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set([ "list_dataset_ids", @@ -77,7 +84,7 @@ async def test_bigquery_toolset_tools_selective(selected_tools): assert tools is not None assert len(tools) == len(selected_tools) - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set(selected_tools) actual_tool_names = set([tool.name for tool in tools]) @@ -114,7 +121,7 @@ async def test_bigquery_toolset_unknown_tool(selected_tools, returned_tools): assert tools is not None assert len(tools) == len(returned_tools) - assert all([isinstance(tool, BigQueryTool) for tool in tools]) + assert all([isinstance(tool, GoogleTool) for tool in tools]) expected_tool_names = set(returned_tools) actual_tool_names = set([tool.name for tool in tools]) diff --git a/tests/unittests/tools/spanner/__init__.py b/tests/unittests/tools/spanner/__init__.py new file mode 100644 index 00000000..60cac4f4 --- /dev/null +++ b/tests/unittests/tools/spanner/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/tests/unittests/tools/spanner/test_spanner_client.py b/tests/unittests/tools/spanner/test_spanner_client.py new file mode 100644 index 00000000..0aaf6967 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_client.py @@ -0,0 +1,142 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import re +from unittest import mock + +from google.adk.tools.spanner.client import get_spanner_client +from google.auth.exceptions import DefaultCredentialsError +from google.oauth2.credentials import Credentials +import pytest + + +def test_spanner_client_project(): + """Test spanner client project.""" + # Trigger the spanner client creation + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the client has the desired project set + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_explicit(): + """Test spanner client creation does not invoke default auth.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate 
exception from default auth + mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the spanner client creation + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_with_default_auth(): + """Test spanner client creation invokes default auth to set the project.""" + # Let's simulate that no environment variables are set, so that any project + # set in there does not interfere with this test + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate credentials + mock_creds = mock.create_autospec(Credentials, instance=True) + + # Simulate output of the default auth + mock_default_auth.return_value = (mock_creds, "test-gcp-project") + + # Trigger the spanner client creation + client = get_spanner_client( + project=None, + credentials=mock_creds, + ) + + # Verify that default auth was called once to set the client project + mock_default_auth.assert_called_once() + assert client.project == "test-gcp-project" + + +def test_spanner_client_project_set_with_env(): + """Test spanner client creation sets the project from environment variable.""" + # Let's simulate the project set in environment variables + with mock.patch.dict( + os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True + ): + with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + # Simulate exception from default auth + 
mock_default_auth.side_effect = DefaultCredentialsError( + "Your default credentials were not found" + ) + + # Trigger the spanner client creation + client = get_spanner_client( + project=None, + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # If we are here that already means client creation did not call default + # auth (otherwise we would have run into DefaultCredentialsError set + # above). For the sake of explicitness, trivially assert that the default + # auth was not called, and yet the project was set correctly + mock_default_auth.assert_not_called() + assert client.project == "test-gcp-project" + + +def test_spanner_client_user_agent(): + """Test spanner client user agent.""" + # Patch the Client constructor + with mock.patch( + "google.cloud.spanner.Client", autospec=True + ) as mock_client_class: + # The mock instance that will be returned by spanner.Client() + mock_instance = mock_client_class.return_value + # The real spanner.Client instance has a `_client_info` attribute. + # We need to add it to our mock instance so that the user_agent can be set. + mock_instance._client_info = mock.Mock() + + # Call the function that creates the client + client = get_spanner_client( + project="test-gcp-project", + credentials=mock.create_autospec(Credentials, instance=True), + ) + + # Verify that the Spanner Client was instantiated. + mock_client_class.assert_called_once_with( + project="test-gcp-project", + credentials=mock.ANY, + ) + + # Verify that the user_agent was set on the client instance. + # The client returned by get_spanner_client is the mock instance. 
+ assert re.search( + r"adk-spanner-tool google-adk/([0-9A-Za-z._\-+/]+)", + client._client_info.user_agent, + ) diff --git a/tests/unittests/tools/spanner/test_spanner_credentials.py b/tests/unittests/tools/spanner/test_spanner_credentials.py new file mode 100644 index 00000000..19430e14 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_credentials.py @@ -0,0 +1,54 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +# Mock the Google OAuth and API dependencies +import google.auth.credentials +import google.oauth2.credentials +import pytest + + +class TestSpannerCredentials: + """Test suite for Spanner credentials configuration validation. + + This class tests the credential configuration logic that ensures + either existing credentials or client ID/secret pairs are provided. + """ + + def test_valid_credentials_object_oauth2_credentials(self): + """Test that providing valid Credentials object works correctly with google.oauth2.credentials.Credentials. + + When a user already has valid OAuth credentials, they should be able + to pass them directly without needing to provide client ID/secret. 
+ """ + # Create a mock oauth2 credentials object + oauth2_creds = google.oauth2.credentials.Credentials( + "test_token", + client_id="test_client_id", + client_secret="test_client_secret", + scopes=[], + ) + + config = SpannerCredentialsConfig(credentials=oauth2_creds) + + # Verify that the credentials are properly stored and attributes are + # extracted + assert config.credentials == oauth2_creds + assert config.client_id == "test_client_id" + assert config.client_secret == "test_client_secret" + assert config.scopes == [ + "https://www.googleapis.com/auth/spanner.data", + ] + + assert config._token_cache_key == "spanner_token_cache" # pylint: disable=protected-access diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py new file mode 100644 index 00000000..f74922b2 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -0,0 +1,27 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from google.adk.tools.spanner.settings import SpannerToolSettings +import pytest + + +def test_spanner_tool_settings_experimental_warning(): + """Test SpannerToolSettings experimental warning.""" + with pytest.warns( + UserWarning, + match="Tool settings defaults may have breaking change in the future.", + ): + SpannerToolSettings() diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py new file mode 100644 index 00000000..73a780f8 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_toolset.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner import SpannerCredentialsConfig +from google.adk.tools.spanner import SpannerToolset +from google.adk.tools.spanner.settings import SpannerToolSettings +import pytest + + +@pytest.mark.asyncio +async def test_spanner_toolset_tools_default(): + """Test default Spanner toolset. + + This test verifies the behavior of the Spanner toolset when no filter is + specified. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = SpannerToolset(credentials_config=credentials_config) + assert isinstance(toolset._tool_settings, SpannerToolSettings) # pylint: disable=protected-access + assert toolset._tool_settings.__dict__ == SpannerToolSettings().__dict__ # pylint: disable=protected-access + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 6 + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set([ + "list_table_names", + "list_table_indexes", + "list_table_index_columns", + "list_named_schemas", + "get_table_schema", + "execute_sql", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + "selected_tools", + [ + pytest.param([], id="None"), + pytest.param( + ["list_table_names", "get_table_schema"], + id="table-metadata", + ), + pytest.param(["execute_sql"], id="query"), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_selective(selected_tools): + """Test selective Spanner toolset. + + This test verifies the behavior of the Spanner toolset when a filter is + specified. + + Args: + selected_tools: A list of tool names to filter. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=SpannerToolSettings(), + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(selected_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(selected_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param(["unknown"], [], id="all-unknown"), + pytest.param( + ["unknown", "execute_sql"], + ["execute_sql"], + id="mixed-known-unknown", + ), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_unknown_tool(selected_tools, returned_tools): + """Test Spanner toolset with unknown tools. + + This test verifies the behavior of the Spanner toolset when unknown tools are + specified in the filter. + + Args: + selected_tools: A list of tool names to filter, including unknown ones. + returned_tools: A list of tool names that are expected to be returned. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=SpannerToolSettings(), + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names + + +@pytest.mark.parametrize( + ("selected_tools", "returned_tools"), + [ + pytest.param( + ["execute_sql", "list_table_names"], + ["list_table_names"], + id="read-not-added", + ), + pytest.param( + ["list_table_names", "list_table_indexes"], + ["list_table_names", "list_table_indexes"], + id="no-effect", + ), + ], +) +@pytest.mark.asyncio +async def test_spanner_toolset_without_read_capability( + selected_tools, returned_tools +): + """Test Spanner toolset without read capability. + + This test verifies the behavior of the Spanner toolset when read capability is + not enabled. + + Args: + selected_tools: A list of tool names to filter. + returned_tools: A list of tool names that are expected to be returned. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + spanner_tool_settings = SpannerToolSettings(capabilities=[]) + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=selected_tools, + spanner_tool_settings=spanner_tool_settings, + ) + + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == len(returned_tools) + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set(returned_tools) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py b/tests/unittests/tools/test_base_google_credentials_manager.py similarity index 96% rename from tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py rename to tests/unittests/tools/test_base_google_credentials_manager.py index 73ffa3bd..de568543 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials_manager.py +++ b/tests/unittests/tools/test_base_google_credentials_manager.py @@ -18,9 +18,9 @@ from unittest.mock import Mock from unittest.mock import patch from google.adk.auth.auth_tool import AuthConfig +from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BIGQUERY_TOKEN_CACHE_KEY from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig -from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials as AuthCredentials from google.auth.exceptions import RefreshError @@ -29,8 +29,8 @@ from google.oauth2.credentials import Credentials as OAuthCredentials import pytest -class TestBigQueryCredentialsManager: - """Test suite for BigQueryCredentialsManager OAuth flow handling. 
+class TestGoogleCredentialsManager: + """Test suite for GoogleCredentialsManager OAuth flow handling. This class tests the complex credential management logic including credential validation, refresh, OAuth flow orchestration, and the @@ -63,7 +63,7 @@ class TestBigQueryCredentialsManager: @pytest.fixture def manager(self, credentials_config): """Create a credentials manager instance for testing.""" - return BigQueryCredentialsManager(credentials_config) + return GoogleCredentialsManager(credentials_config) @pytest.mark.parametrize( ("credentials_class",), @@ -336,7 +336,7 @@ class TestBigQueryCredentialsManager: # Use the full module path as it appears in the project structure with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: result = await manager.get_valid_credentials(mock_tool_context) @@ -388,7 +388,7 @@ class TestBigQueryCredentialsManager: credential manager, avoiding redundant OAuth flows. """ # Create first manager instance and simulate OAuth completion - manager1 = BigQueryCredentialsManager(credentials_config) + manager1 = GoogleCredentialsManager(credentials_config) # Mock OAuth response for first manager mock_auth_response = Mock() @@ -412,7 +412,7 @@ class TestBigQueryCredentialsManager: # Use the correct module path - without the 'src.' 
prefix with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials", + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials", return_value=mock_creds, ) as mock_credentials_class: # Complete OAuth flow with first manager @@ -424,7 +424,7 @@ class TestBigQueryCredentialsManager: assert cached_creds_json == mock_creds_json # Create second manager instance (simulating new request/session) - manager2 = BigQueryCredentialsManager(credentials_config) + manager2 = GoogleCredentialsManager(credentials_config) credentials_config.credentials = None # Reset auth response to None (no new OAuth flow available) @@ -432,7 +432,7 @@ class TestBigQueryCredentialsManager: # Mock the from_authorized_user_info method for the second manager with patch( - "google.adk.tools.bigquery.bigquery_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" + "google.adk.tools._google_credentials.google.oauth2.credentials.Credentials.from_authorized_user_info" ) as mock_from_json: mock_cached_creds = Mock(spec=OAuthCredentials) mock_cached_creds.valid = True diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool.py b/tests/unittests/tools/test_google_tool.py similarity index 84% rename from tests/unittests/tools/bigquery/test_bigquery_tool.py rename to tests/unittests/tools/test_google_tool.py index 5b1441d4..fb9da070 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool.py +++ b/tests/unittests/tools/test_google_tool.py @@ -16,18 +16,19 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.tools._google_credentials import GoogleCredentialsManager from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig -from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsManager -from google.adk.tools.bigquery.bigquery_tool import BigQueryTool from google.adk.tools.bigquery.config import BigQueryToolConfig +from 
google.adk.tools.google_tool import GoogleTool +from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.tool_context import ToolContext # Mock the Google OAuth and API dependencies from google.oauth2.credentials import Credentials import pytest -class TestBigQueryTool: - """Test suite for BigQueryTool OAuth integration and execution. +class TestGoogleTool: + """Test suite for GoogleTool OAuth integration and execution. This class tests the high-level tool execution logic that combines credential management with actual function execution. @@ -88,18 +89,18 @@ class TestBigQueryTool: def test_tool_initialization_with_credentials( self, sample_function, credentials_config ): - """Test that BigQueryTool initializes correctly with credentials. + """Test that GoogleTool initializes correctly with credentials. The tool should properly inherit from FunctionTool while adding Google API specific credential management capabilities. """ - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) assert tool.func == sample_function assert tool._credentials_manager is not None - assert isinstance(tool._credentials_manager, BigQueryCredentialsManager) + assert isinstance(tool._credentials_manager, GoogleCredentialsManager) # Verify that 'credentials' parameter is ignored in function signature analysis assert "credentials" in tool._ignore_params @@ -109,7 +110,7 @@ class TestBigQueryTool: Some tools might handle authentication externally or use service accounts, so credential management should be optional. """ - tool = BigQueryTool(func=sample_function, credentials_config=None) + tool = GoogleTool(func=sample_function, credentials_config=None) assert tool.func == sample_function assert tool._credentials_manager is None @@ -123,7 +124,7 @@ class TestBigQueryTool: This tests the main happy path where credentials are available and the underlying function executes successfully. 
""" - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) @@ -152,7 +153,7 @@ class TestBigQueryTool: When credentials aren't available and OAuth flow is needed, the tool should return a user-friendly message rather than failing. """ - tool = BigQueryTool( + tool = GoogleTool( func=sample_function, credentials_config=credentials_config ) @@ -178,7 +179,7 @@ class TestBigQueryTool: Tools without credential managers should execute normally, passing None for credentials if the function accepts them. """ - tool = BigQueryTool(func=sample_function, credentials_config=None) + tool = GoogleTool(func=sample_function, credentials_config=None) result = await tool.run_async( args={"param1": "test_value"}, tool_context=mock_tool_context @@ -196,7 +197,7 @@ class TestBigQueryTool: The tool should correctly detect and execute async functions, which is important for tools that make async API calls. """ - tool = BigQueryTool( + tool = GoogleTool( func=async_sample_function, credentials_config=credentials_config ) @@ -227,7 +228,7 @@ class TestBigQueryTool: def failing_function(param1: str, credentials: Credentials = None) -> dict: raise ValueError("Something went wrong") - tool = BigQueryTool( + tool = GoogleTool( func=failing_function, credentials_config=credentials_config ) @@ -259,7 +260,7 @@ class TestBigQueryTool: ) -> dict: return {"success": True} - tool = BigQueryTool( + tool = GoogleTool( func=complex_function, credentials_config=credentials_config ) @@ -270,7 +271,7 @@ class TestBigQueryTool: assert "optional_param" not in mandatory_args @pytest.mark.parametrize( - "input_config, expected_config", + "input_settings, expected_settings", [ pytest.param( BigQueryToolConfig( @@ -281,22 +282,36 @@ class TestBigQueryTool: ), id="with_provided_config", ), - pytest.param( - None, - BigQueryToolConfig(), - id="with_none_config_creates_default", - ), ], ) - def test_tool_config_initialization(self, input_config, expected_config): 
- """Tests that self._tool_config is correctly initialized by comparing its + def test_tool_bigquery_config_initialization( + self, input_settings, expected_settings + ): + """Tests that self._tool_settings is correctly initialized by comparing its final state to an expected configuration object. """ # 1. Initialize the tool with the parameterized config - tool = BigQueryTool(func=None, bigquery_tool_config=input_config) + tool = GoogleTool(func=None, tool_settings=input_settings) # 2. Assert that the tool's config has the same attribute values # as the expected config. Comparing the __dict__ is a robust # way to check for value equality. - assert tool._tool_config.__dict__ == expected_config.__dict__ # pylint: disable=protected-access + assert tool._tool_settings.__dict__ == expected_settings.__dict__ # pylint: disable=protected-access + + @pytest.mark.parametrize( + "input_settings, expected_settings", + [ + pytest.param( + SpannerToolSettings(max_executed_query_result_rows=10), + SpannerToolSettings(max_executed_query_result_rows=10), + id="with_provided_settings", + ), + ], + ) + def test_tool_spanner_settings_initialization( + self, input_settings, expected_settings + ): + """Tests that self._tool_settings is correctly initialized with SpannerToolSettings by comparing its final state to an expected configuration object.""" + tool = GoogleTool(func=None, tool_settings=input_settings) + assert tool._tool_settings.__dict__ == expected_settings.__dict__ # pylint: disable=protected-access