diff --git a/contributing/samples/spanner/agent.py b/contributing/samples/spanner/agent.py index 631fb45b..065cf027 100644 --- a/contributing/samples/spanner/agent.py +++ b/contributing/samples/spanner/agent.py @@ -18,6 +18,7 @@ from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes from google.adk.tools.google_tool import GoogleTool from google.adk.tools.spanner.settings import Capabilities +from google.adk.tools.spanner.settings import QueryResultMode from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig from google.adk.tools.spanner.spanner_toolset import SpannerToolset @@ -34,7 +35,10 @@ CREDENTIALS_TYPE = None # Define Spanner tool config with read capability set to allowed. -tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ]) +tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + query_result_mode=QueryResultMode.DICT_LIST, +) if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initialize the tools to do interactive OAuth diff --git a/src/google/adk/tools/spanner/query_tool.py b/src/google/adk/tools/spanner/query_tool.py index f6fa5a34..51fe2df9 100644 --- a/src/google/adk/tools/spanner/query_tool.py +++ b/src/google/adk/tools/spanner/query_tool.py @@ -14,10 +14,16 @@ from __future__ import annotations +import functools +import textwrap +import types +from typing import Callable + from google.auth.credentials import Credentials from . import utils from ..tool_context import ToolContext +from .settings import QueryResultMode from .settings import SpannerToolSettings @@ -49,16 +55,29 @@ def execute_sql( query not returned in the result. Examples: - Fetch data or insights from a table: + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) AS count FROM my_table") + { + "status": "SUCCESS", + "rows": [ + [100] + ] + } + - >>> execute_sql("my_project", "my_instance", "my_database", - ... "SELECT COUNT(*) AS count FROM my_table") - { - "status": "SUCCESS", - "rows": [ - [100] - ] - } + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT name, rating, description FROM hotels_table") + { + "status": "SUCCESS", + "rows": [ + ["The Hotel", 4.1, "Modern hotel."], + ["Park Inn", 4.5, "Cozy hotel."], + ... + ] + } + Note: This is running with Read-Only Transaction for query that only read data. @@ -72,3 +91,105 @@ def execute_sql( settings, tool_context, ) + + +_EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING = textwrap.dedent("""\ +Run a Spanner Read-Only query in the spanner database and return the result. + +Args: + project_id (str): The GCP project id in which the spanner database + resides. + instance_id (str): The instance id of the spanner database. + database_id (str): The database id of the spanner database. + query (str): The Spanner SQL query to be executed. + credentials (Credentials): The credentials to use for the request. + settings (SpannerToolSettings): The settings for the tool. + tool_context (ToolContext): The context for the tool. + +Returns: + dict: Dictionary with the result of the query. + If the result contains the key "result_is_likely_truncated" with + value True, it means that there may be additional rows matching the + query not returned in the result. + +Examples: + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) AS count FROM my_table") + { + "status": "SUCCESS", + "rows": [ + { + "count": 100 + } + ] + } + + + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) FROM my_table") + { + "status": "SUCCESS", + "rows": [ + { + "": 100 + } + ] + } + + + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT name, rating, description FROM hotels_table") + { + "status": "SUCCESS", + "rows": [ + { + "name": "The Hotel", + "rating": 4.1, + "description": "Modern hotel." + }, + { + "name": "Park Inn", + "rating": 4.5, + "description": "Cozy hotel." + }, + ... + ] + } + + +Note: + This is running with Read-Only Transaction for query that only read data. +""") + + +def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]: + """Get the execute_sql tool customized as per the given tool settings. + + Args: + settings: Spanner tool settings indicating the behavior of the execute_sql + tool. + + Returns: + callable[..., dict]: A version of the execute_sql tool respecting the tool + settings. + """ + + if settings and settings.query_result_mode is QueryResultMode.DICT_LIST: + + execute_sql_wrapper = types.FunctionType( + execute_sql.__code__, + execute_sql.__globals__, + execute_sql.__name__, + execute_sql.__defaults__, + execute_sql.__closure__, + ) + functools.update_wrapper(execute_sql_wrapper, execute_sql) + # Update with the new docstring + execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING + return execute_sql_wrapper + + # Return the default execute_sql function. + return execute_sql diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index a76331ba..3f58d20f 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -40,6 +40,20 @@ class Capabilities(Enum): """Read only data operations tools are allowed.""" +class QueryResultMode(Enum): + """Settings for Spanner execute sql query result.""" + + DEFAULT = "default" + """Return the result of a query as a list of rows data.""" + + DICT_LIST = "dict_list" + """Return the result of a query as a list of dictionaries. + + In each dictionary the key is the column name and the value is the value of + the that column in a given row. + """ + + class SpannerVectorStoreSettings(BaseModel): """Settings for Spanner Vector Store. @@ -140,5 +154,8 @@ class SpannerToolSettings(BaseModel): max_executed_query_result_rows: int = 50 """Maximum number of rows to return from a query result.""" + query_result_mode: QueryResultMode = QueryResultMode.DEFAULT + """Mode for Spanner execute sql query result.""" + vector_store_settings: Optional[SpannerVectorStoreSettings] = None """Settings for Spanner vector store and vector similarity search.""" diff --git a/src/google/adk/tools/spanner/spanner_toolset.py b/src/google/adk/tools/spanner/spanner_toolset.py index 6496014f..20800263 100644 --- a/src/google/adk/tools/spanner/spanner_toolset.py +++ b/src/google/adk/tools/spanner/spanner_toolset.py @@ -111,7 +111,7 @@ class SpannerToolset(BaseToolset): ): all_tools.append( GoogleTool( - func=query_tool.execute_sql, + func=query_tool.get_execute_sql(self._tool_settings), credentials_config=self._credentials_config, tool_settings=self._tool_settings, ) diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py index f1b710ec..adde5219 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -22,6 +22,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from . import client from ..tool_context import ToolContext +from .settings import QueryResultMode from .settings import SpannerToolSettings DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50 @@ -84,6 +85,9 @@ def execute_sql( if settings and settings.max_executed_query_result_rows > 0 else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS ) + if settings and settings.query_result_mode is QueryResultMode.DICT_LIST: + result_set = result_set.to_dict_list() + for row in result_set: try: # if the json serialization of the row succeeds, use it as is diff --git a/tests/unittests/tools/spanner/test_spanner_query_tool.py b/tests/unittests/tools/spanner/test_spanner_query_tool.py new file mode 100644 index 00000000..73b3cb50 --- /dev/null +++ b/tests/unittests/tools/spanner/test_spanner_query_tool.py @@ -0,0 +1,224 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import textwrap +from unittest import mock + +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.spanner import query_tool +from google.adk.tools.spanner import settings +from google.adk.tools.spanner.settings import QueryResultMode +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig +from google.adk.tools.spanner.spanner_toolset import SpannerToolset +from google.adk.tools.tool_context import ToolContext +from google.auth.credentials import Credentials +import pytest + + +async def get_tool( + name: str, tool_settings: SpannerToolSettings | None = None +) -> BaseTool: + """Get a tool from Spanner toolset.""" + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + toolset = SpannerToolset( + credentials_config=credentials_config, + tool_filter=[name], + spanner_tool_settings=tool_settings, + ) + + tools = await toolset.get_tools() + assert tools is not None + assert len(tools) == 1 + return tools[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "query_result_mode, expected_description", + [ + ( + QueryResultMode.DEFAULT, + textwrap.dedent( + """\ + Run a Spanner Read-Only query in the spanner database and return the result. + + Args: + project_id (str): The GCP project id in which the spanner database + resides. + instance_id (str): The instance id of the spanner database. + database_id (str): The database id of the spanner database. + query (str): The Spanner SQL query to be executed. + credentials (Credentials): The credentials to use for the request. + settings (SpannerToolSettings): The settings for the tool. + tool_context (ToolContext): The context for the tool. + + Returns: + dict: Dictionary with the result of the query. + If the result contains the key "result_is_likely_truncated" with + value True, it means that there may be additional rows matching the + query not returned in the result. + + Examples: + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) AS count FROM my_table") + { + "status": "SUCCESS", + "rows": [ + [100] + ] + } + + + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT name, rating, description FROM hotels_table") + { + "status": "SUCCESS", + "rows": [ + ["The Hotel", 4.1, "Modern hotel."], + ["Park Inn", 4.5, "Cozy hotel."], + ... + ] + } + + + Note: + This is running with Read-Only Transaction for query that only read data.""" + ), + ), + ( + QueryResultMode.DICT_LIST, + textwrap.dedent( + """\ + Run a Spanner Read-Only query in the spanner database and return the result. + + Args: + project_id (str): The GCP project id in which the spanner database + resides. + instance_id (str): The instance id of the spanner database. + database_id (str): The database id of the spanner database. + query (str): The Spanner SQL query to be executed. + credentials (Credentials): The credentials to use for the request. + settings (SpannerToolSettings): The settings for the tool. + tool_context (ToolContext): The context for the tool. + + Returns: + dict: Dictionary with the result of the query. + If the result contains the key "result_is_likely_truncated" with + value True, it means that there may be additional rows matching the + query not returned in the result. + + Examples: + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) AS count FROM my_table") + { + "status": "SUCCESS", + "rows": [ + { + "count": 100 + } + ] + } + + + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT COUNT(*) FROM my_table") + { + "status": "SUCCESS", + "rows": [ + { + "": 100 + } + ] + } + + + + >>> execute_sql("my_project", "my_instance", "my_database", + ... "SELECT name, rating, description FROM hotels_table") + { + "status": "SUCCESS", + "rows": [ + { + "name": "The Hotel", + "rating": 4.1, + "description": "Modern hotel." + }, + { + "name": "Park Inn", + "rating": 4.5, + "description": "Cozy hotel." + }, + ... + ] + } + + + Note: + This is running with Read-Only Transaction for query that only read data.""" + ), + ), + ], +) +async def test_execute_sql_query_result( + query_result_mode, expected_description +): + """Test Spanner execute_sql tool query result in different modes.""" + tool_name = "execute_sql" + tool_settings = SpannerToolSettings(query_result_mode=query_result_mode) + tool = await get_tool(tool_name, tool_settings) + assert tool.name == tool_name + assert tool.description == expected_description + + +@mock.patch.object(query_tool.utils, "execute_sql", spec_set=True) +def test_execute_sql(mock_utils_execute_sql): + """Test execute_sql function in query result default mode.""" + mock_credentials = mock.create_autospec( + Credentials, instance=True, spec_set=True + ) + mock_tool_context = mock.create_autospec( + ToolContext, instance=True, spec_set=True + ) + mock_utils_execute_sql.return_value = {"status": "SUCCESS", "rows": [[1]]} + + result = query_tool.execute_sql( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + query="SELECT 1", + credentials=mock_credentials, + settings=settings.SpannerToolSettings(), + tool_context=mock_tool_context, + ) + + mock_utils_execute_sql.assert_called_once_with( + "test-project", + "test-instance", + "test-database", + "SELECT 1", + mock_credentials, + settings.SpannerToolSettings(), + mock_tool_context, + ) + assert result == {"status": "SUCCESS", "rows": [[1]]} diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py index 730c9e0e..1865f5f4 100644 --- a/tests/unittests/tools/spanner/test_spanner_tool_settings.py +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -14,6 +14,8 @@ from __future__ import annotations +from google.adk.tools.spanner.settings import Capabilities +from google.adk.tools.spanner.settings import QueryResultMode from google.adk.tools.spanner.settings import SpannerToolSettings from google.adk.tools.spanner.settings import SpannerVectorStoreSettings from pydantic import ValidationError @@ -70,3 +72,26 @@ def test_spanner_vector_store_settings_invalid_vector_length(): assert "Invalid vector length in the Spanner vector store settings." in str( excinfo.value ) + + +@pytest.mark.parametrize( + "settings_args, expected_rows, expected_mode", + [ + ({}, 50, QueryResultMode.DEFAULT), + ( + { + "capabilities": [Capabilities.DATA_READ], + "max_executed_query_result_rows": 100, + "query_result_mode": QueryResultMode.DICT_LIST, + }, + 100, + QueryResultMode.DICT_LIST, + ), + ], +) +def test_spanner_tool_settings(settings_args, expected_rows, expected_mode): + """Test SpannerToolSettings with different values.""" + settings = SpannerToolSettings(**settings_args) + assert settings.capabilities == [Capabilities.DATA_READ] + assert settings.max_executed_query_result_rows == expected_rows + assert settings.query_result_mode == expected_mode