feat: add Spanner execute sql query result mode

Add using the execute sql query return result as list of dictionaries.
In each dictionary the key is the column name and the value is the value of
the that column in a given row.

PiperOrigin-RevId: 840909555
This commit is contained in:
Google Team Member
2025-12-05 16:00:09 -08:00
committed by Copybara-Service
parent de841a4a09
commit f22bac0b20
7 changed files with 406 additions and 11 deletions
+5 -1
View File
@@ -18,6 +18,7 @@ from google.adk.agents.llm_agent import LlmAgent
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.tools.google_tool import GoogleTool
from google.adk.tools.spanner.settings import Capabilities
from google.adk.tools.spanner.settings import QueryResultMode
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
from google.adk.tools.spanner.spanner_toolset import SpannerToolset
@@ -34,7 +35,10 @@ CREDENTIALS_TYPE = None
# Define Spanner tool config with read capability set to allowed.
tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ])
tool_settings = SpannerToolSettings(
capabilities=[Capabilities.DATA_READ],
query_result_mode=QueryResultMode.DICT_LIST,
)
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initialize the tools to do interactive OAuth
+130 -9
View File
@@ -14,10 +14,16 @@
from __future__ import annotations
import functools
import textwrap
import types
from typing import Callable
from google.auth.credentials import Credentials
from . import utils
from ..tool_context import ToolContext
from .settings import QueryResultMode
from .settings import SpannerToolSettings
@@ -49,16 +55,29 @@ def execute_sql(
query not returned in the result.
Examples:
Fetch data or insights from a table:
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) AS count FROM my_table")
{
"status": "SUCCESS",
"rows": [
[100]
]
}
</Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) AS count FROM my_table")
{
"status": "SUCCESS",
"rows": [
[100]
]
}
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT name, rating, description FROM hotels_table")
{
"status": "SUCCESS",
"rows": [
["The Hotel", 4.1, "Modern hotel."],
["Park Inn", 4.5, "Cozy hotel."],
...
]
}
</Example>
Note:
This is running with Read-Only Transaction for query that only read data.
@@ -72,3 +91,105 @@ def execute_sql(
settings,
tool_context,
)
_EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING = textwrap.dedent("""\
Run a Spanner Read-Only query in the spanner database and return the result.
Args:
project_id (str): The GCP project id in which the spanner database
resides.
instance_id (str): The instance id of the spanner database.
database_id (str): The database id of the spanner database.
query (str): The Spanner SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
settings (SpannerToolSettings): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
dict: Dictionary with the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) AS count FROM my_table")
{
"status": "SUCCESS",
"rows": [
{
"count": 100
}
]
}
</Example>
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) FROM my_table")
{
"status": "SUCCESS",
"rows": [
{
"": 100
}
]
}
</Example>
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT name, rating, description FROM hotels_table")
{
"status": "SUCCESS",
"rows": [
{
"name": "The Hotel",
"rating": 4.1,
"description": "Modern hotel."
},
{
"name": "Park Inn",
"rating": 4.5,
"description": "Cozy hotel."
},
...
]
}
</Example>
Note:
This is running with Read-Only Transaction for query that only read data.
""")
def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]:
"""Get the execute_sql tool customized as per the given tool settings.
Args:
settings: Spanner tool settings indicating the behavior of the execute_sql
tool.
Returns:
callable[..., dict]: A version of the execute_sql tool respecting the tool
settings.
"""
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
execute_sql_wrapper = types.FunctionType(
execute_sql.__code__,
execute_sql.__globals__,
execute_sql.__name__,
execute_sql.__defaults__,
execute_sql.__closure__,
)
functools.update_wrapper(execute_sql_wrapper, execute_sql)
# Update with the new docstring
execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING
return execute_sql_wrapper
# Return the default execute_sql function.
return execute_sql
+17
View File
@@ -40,6 +40,20 @@ class Capabilities(Enum):
"""Read only data operations tools are allowed."""
class QueryResultMode(Enum):
"""Settings for Spanner execute sql query result."""
DEFAULT = "default"
"""Return the result of a query as a list of rows data."""
DICT_LIST = "dict_list"
"""Return the result of a query as a list of dictionaries.
In each dictionary the key is the column name and the value is the value of
the that column in a given row.
"""
class SpannerVectorStoreSettings(BaseModel):
"""Settings for Spanner Vector Store.
@@ -140,5 +154,8 @@ class SpannerToolSettings(BaseModel):
max_executed_query_result_rows: int = 50
"""Maximum number of rows to return from a query result."""
query_result_mode: QueryResultMode = QueryResultMode.DEFAULT
"""Mode for Spanner execute sql query result."""
vector_store_settings: Optional[SpannerVectorStoreSettings] = None
"""Settings for Spanner vector store and vector similarity search."""
@@ -111,7 +111,7 @@ class SpannerToolset(BaseToolset):
):
all_tools.append(
GoogleTool(
func=query_tool.execute_sql,
func=query_tool.get_execute_sql(self._tool_settings),
credentials_config=self._credentials_config,
tool_settings=self._tool_settings,
)
+4
View File
@@ -22,6 +22,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from . import client
from ..tool_context import ToolContext
from .settings import QueryResultMode
from .settings import SpannerToolSettings
DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50
@@ -84,6 +85,9 @@ def execute_sql(
if settings and settings.max_executed_query_result_rows > 0
else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
)
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
result_set = result_set.to_dict_list()
for row in result_set:
try:
# if the json serialization of the row succeeds, use it as is
@@ -0,0 +1,224 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import textwrap
from unittest import mock
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.spanner import query_tool
from google.adk.tools.spanner import settings
from google.adk.tools.spanner.settings import QueryResultMode
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
from google.adk.tools.spanner.spanner_toolset import SpannerToolset
from google.adk.tools.tool_context import ToolContext
from google.auth.credentials import Credentials
import pytest
async def get_tool(
name: str, tool_settings: SpannerToolSettings | None = None
) -> BaseTool:
"""Get a tool from Spanner toolset."""
credentials_config = SpannerCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = SpannerToolset(
credentials_config=credentials_config,
tool_filter=[name],
spanner_tool_settings=tool_settings,
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 1
return tools[0]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"query_result_mode, expected_description",
[
(
QueryResultMode.DEFAULT,
textwrap.dedent(
"""\
Run a Spanner Read-Only query in the spanner database and return the result.
Args:
project_id (str): The GCP project id in which the spanner database
resides.
instance_id (str): The instance id of the spanner database.
database_id (str): The database id of the spanner database.
query (str): The Spanner SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
settings (SpannerToolSettings): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
dict: Dictionary with the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) AS count FROM my_table")
{
"status": "SUCCESS",
"rows": [
[100]
]
}
</Example>
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT name, rating, description FROM hotels_table")
{
"status": "SUCCESS",
"rows": [
["The Hotel", 4.1, "Modern hotel."],
["Park Inn", 4.5, "Cozy hotel."],
...
]
}
</Example>
Note:
This is running with Read-Only Transaction for query that only read data."""
),
),
(
QueryResultMode.DICT_LIST,
textwrap.dedent(
"""\
Run a Spanner Read-Only query in the spanner database and return the result.
Args:
project_id (str): The GCP project id in which the spanner database
resides.
instance_id (str): The instance id of the spanner database.
database_id (str): The database id of the spanner database.
query (str): The Spanner SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
settings (SpannerToolSettings): The settings for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
dict: Dictionary with the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) AS count FROM my_table")
{
"status": "SUCCESS",
"rows": [
{
"count": 100
}
]
}
</Example>
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT COUNT(*) FROM my_table")
{
"status": "SUCCESS",
"rows": [
{
"": 100
}
]
}
</Example>
<Example>
>>> execute_sql("my_project", "my_instance", "my_database",
... "SELECT name, rating, description FROM hotels_table")
{
"status": "SUCCESS",
"rows": [
{
"name": "The Hotel",
"rating": 4.1,
"description": "Modern hotel."
},
{
"name": "Park Inn",
"rating": 4.5,
"description": "Cozy hotel."
},
...
]
}
</Example>
Note:
This is running with Read-Only Transaction for query that only read data."""
),
),
],
)
async def test_execute_sql_query_result(
query_result_mode, expected_description
):
"""Test Spanner execute_sql tool query result in different modes."""
tool_name = "execute_sql"
tool_settings = SpannerToolSettings(query_result_mode=query_result_mode)
tool = await get_tool(tool_name, tool_settings)
assert tool.name == tool_name
assert tool.description == expected_description
@mock.patch.object(query_tool.utils, "execute_sql", spec_set=True)
def test_execute_sql(mock_utils_execute_sql):
"""Test execute_sql function in query result default mode."""
mock_credentials = mock.create_autospec(
Credentials, instance=True, spec_set=True
)
mock_tool_context = mock.create_autospec(
ToolContext, instance=True, spec_set=True
)
mock_utils_execute_sql.return_value = {"status": "SUCCESS", "rows": [[1]]}
result = query_tool.execute_sql(
project_id="test-project",
instance_id="test-instance",
database_id="test-database",
query="SELECT 1",
credentials=mock_credentials,
settings=settings.SpannerToolSettings(),
tool_context=mock_tool_context,
)
mock_utils_execute_sql.assert_called_once_with(
"test-project",
"test-instance",
"test-database",
"SELECT 1",
mock_credentials,
settings.SpannerToolSettings(),
mock_tool_context,
)
assert result == {"status": "SUCCESS", "rows": [[1]]}
@@ -14,6 +14,8 @@
from __future__ import annotations
from google.adk.tools.spanner.settings import Capabilities
from google.adk.tools.spanner.settings import QueryResultMode
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
from pydantic import ValidationError
@@ -70,3 +72,26 @@ def test_spanner_vector_store_settings_invalid_vector_length():
assert "Invalid vector length in the Spanner vector store settings." in str(
excinfo.value
)
@pytest.mark.parametrize(
"settings_args, expected_rows, expected_mode",
[
({}, 50, QueryResultMode.DEFAULT),
(
{
"capabilities": [Capabilities.DATA_READ],
"max_executed_query_result_rows": 100,
"query_result_mode": QueryResultMode.DICT_LIST,
},
100,
QueryResultMode.DICT_LIST,
),
],
)
def test_spanner_tool_settings(settings_args, expected_rows, expected_mode):
"""Test SpannerToolSettings with different values."""
settings = SpannerToolSettings(**settings_args)
assert settings.capabilities == [Capabilities.DATA_READ]
assert settings.max_executed_query_result_rows == expected_rows
assert settings.query_result_mode == expected_mode