You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: add Spanner execute sql query result mode
Add using the execute sql query return result as list of dictionaries. In each dictionary the key is the column name and the value is the value of the that column in a given row. PiperOrigin-RevId: 840909555
This commit is contained in:
committed by
Copybara-Service
parent
de841a4a09
commit
f22bac0b20
@@ -18,6 +18,7 @@ from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.google_tool import GoogleTool
|
||||
from google.adk.tools.spanner.settings import Capabilities
|
||||
from google.adk.tools.spanner.settings import QueryResultMode
|
||||
from google.adk.tools.spanner.settings import SpannerToolSettings
|
||||
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
|
||||
from google.adk.tools.spanner.spanner_toolset import SpannerToolset
|
||||
@@ -34,7 +35,10 @@ CREDENTIALS_TYPE = None
|
||||
|
||||
|
||||
# Define Spanner tool config with read capability set to allowed.
|
||||
tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ])
|
||||
tool_settings = SpannerToolSettings(
|
||||
capabilities=[Capabilities.DATA_READ],
|
||||
query_result_mode=QueryResultMode.DICT_LIST,
|
||||
)
|
||||
|
||||
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
|
||||
# Initialize the tools to do interactive OAuth
|
||||
|
||||
@@ -14,10 +14,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import textwrap
|
||||
import types
|
||||
from typing import Callable
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
from . import utils
|
||||
from ..tool_context import ToolContext
|
||||
from .settings import QueryResultMode
|
||||
from .settings import SpannerToolSettings
|
||||
|
||||
|
||||
@@ -49,16 +55,29 @@ def execute_sql(
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
Fetch data or insights from a table:
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) AS count FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
[100]
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) AS count FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
[100]
|
||||
]
|
||||
}
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT name, rating, description FROM hotels_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
["The Hotel", 4.1, "Modern hotel."],
|
||||
["Park Inn", 4.5, "Cozy hotel."],
|
||||
...
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
Note:
|
||||
This is running with Read-Only Transaction for query that only read data.
|
||||
@@ -72,3 +91,105 @@ def execute_sql(
|
||||
settings,
|
||||
tool_context,
|
||||
)
|
||||
|
||||
|
||||
_EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING = textwrap.dedent("""\
|
||||
Run a Spanner Read-Only query in the spanner database and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the spanner database
|
||||
resides.
|
||||
instance_id (str): The instance id of the spanner database.
|
||||
database_id (str): The database id of the spanner database.
|
||||
query (str): The Spanner SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (SpannerToolSettings): The settings for the tool.
|
||||
tool_context (ToolContext): The context for the tool.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) AS count FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"count": 100
|
||||
}
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"": 100
|
||||
}
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT name, rating, description FROM hotels_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"name": "The Hotel",
|
||||
"rating": 4.1,
|
||||
"description": "Modern hotel."
|
||||
},
|
||||
{
|
||||
"name": "Park Inn",
|
||||
"rating": 4.5,
|
||||
"description": "Cozy hotel."
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
Note:
|
||||
This is running with Read-Only Transaction for query that only read data.
|
||||
""")
|
||||
|
||||
|
||||
def get_execute_sql(settings: SpannerToolSettings) -> Callable[..., dict]:
|
||||
"""Get the execute_sql tool customized as per the given tool settings.
|
||||
|
||||
Args:
|
||||
settings: Spanner tool settings indicating the behavior of the execute_sql
|
||||
tool.
|
||||
|
||||
Returns:
|
||||
callable[..., dict]: A version of the execute_sql tool respecting the tool
|
||||
settings.
|
||||
"""
|
||||
|
||||
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
|
||||
|
||||
execute_sql_wrapper = types.FunctionType(
|
||||
execute_sql.__code__,
|
||||
execute_sql.__globals__,
|
||||
execute_sql.__name__,
|
||||
execute_sql.__defaults__,
|
||||
execute_sql.__closure__,
|
||||
)
|
||||
functools.update_wrapper(execute_sql_wrapper, execute_sql)
|
||||
# Update with the new docstring
|
||||
execute_sql_wrapper.__doc__ = _EXECUTE_SQL_DICT_LIST_MODE_DOCSTRING
|
||||
return execute_sql_wrapper
|
||||
|
||||
# Return the default execute_sql function.
|
||||
return execute_sql
|
||||
|
||||
@@ -40,6 +40,20 @@ class Capabilities(Enum):
|
||||
"""Read only data operations tools are allowed."""
|
||||
|
||||
|
||||
class QueryResultMode(Enum):
|
||||
"""Settings for Spanner execute sql query result."""
|
||||
|
||||
DEFAULT = "default"
|
||||
"""Return the result of a query as a list of rows data."""
|
||||
|
||||
DICT_LIST = "dict_list"
|
||||
"""Return the result of a query as a list of dictionaries.
|
||||
|
||||
In each dictionary the key is the column name and the value is the value of
|
||||
the that column in a given row.
|
||||
"""
|
||||
|
||||
|
||||
class SpannerVectorStoreSettings(BaseModel):
|
||||
"""Settings for Spanner Vector Store.
|
||||
|
||||
@@ -140,5 +154,8 @@ class SpannerToolSettings(BaseModel):
|
||||
max_executed_query_result_rows: int = 50
|
||||
"""Maximum number of rows to return from a query result."""
|
||||
|
||||
query_result_mode: QueryResultMode = QueryResultMode.DEFAULT
|
||||
"""Mode for Spanner execute sql query result."""
|
||||
|
||||
vector_store_settings: Optional[SpannerVectorStoreSettings] = None
|
||||
"""Settings for Spanner vector store and vector similarity search."""
|
||||
|
||||
@@ -111,7 +111,7 @@ class SpannerToolset(BaseToolset):
|
||||
):
|
||||
all_tools.append(
|
||||
GoogleTool(
|
||||
func=query_tool.execute_sql,
|
||||
func=query_tool.get_execute_sql(self._tool_settings),
|
||||
credentials_config=self._credentials_config,
|
||||
tool_settings=self._tool_settings,
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
|
||||
|
||||
from . import client
|
||||
from ..tool_context import ToolContext
|
||||
from .settings import QueryResultMode
|
||||
from .settings import SpannerToolSettings
|
||||
|
||||
DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50
|
||||
@@ -84,6 +85,9 @@ def execute_sql(
|
||||
if settings and settings.max_executed_query_result_rows > 0
|
||||
else DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS
|
||||
)
|
||||
if settings and settings.query_result_mode is QueryResultMode.DICT_LIST:
|
||||
result_set = result_set.to_dict_list()
|
||||
|
||||
for row in result_set:
|
||||
try:
|
||||
# if the json serialization of the row succeeds, use it as is
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.spanner import query_tool
|
||||
from google.adk.tools.spanner import settings
|
||||
from google.adk.tools.spanner.settings import QueryResultMode
|
||||
from google.adk.tools.spanner.settings import SpannerToolSettings
|
||||
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
|
||||
from google.adk.tools.spanner.spanner_toolset import SpannerToolset
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.auth.credentials import Credentials
|
||||
import pytest
|
||||
|
||||
|
||||
async def get_tool(
|
||||
name: str, tool_settings: SpannerToolSettings | None = None
|
||||
) -> BaseTool:
|
||||
"""Get a tool from Spanner toolset."""
|
||||
credentials_config = SpannerCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
|
||||
toolset = SpannerToolset(
|
||||
credentials_config=credentials_config,
|
||||
tool_filter=[name],
|
||||
spanner_tool_settings=tool_settings,
|
||||
)
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
assert len(tools) == 1
|
||||
return tools[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"query_result_mode, expected_description",
|
||||
[
|
||||
(
|
||||
QueryResultMode.DEFAULT,
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
Run a Spanner Read-Only query in the spanner database and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the spanner database
|
||||
resides.
|
||||
instance_id (str): The instance id of the spanner database.
|
||||
database_id (str): The database id of the spanner database.
|
||||
query (str): The Spanner SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (SpannerToolSettings): The settings for the tool.
|
||||
tool_context (ToolContext): The context for the tool.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) AS count FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
[100]
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT name, rating, description FROM hotels_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
["The Hotel", 4.1, "Modern hotel."],
|
||||
["Park Inn", 4.5, "Cozy hotel."],
|
||||
...
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
Note:
|
||||
This is running with Read-Only Transaction for query that only read data."""
|
||||
),
|
||||
),
|
||||
(
|
||||
QueryResultMode.DICT_LIST,
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
Run a Spanner Read-Only query in the spanner database and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the spanner database
|
||||
resides.
|
||||
instance_id (str): The instance id of the spanner database.
|
||||
database_id (str): The database id of the spanner database.
|
||||
query (str): The Spanner SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
settings (SpannerToolSettings): The settings for the tool.
|
||||
tool_context (ToolContext): The context for the tool.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) AS count FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"count": 100
|
||||
}
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT COUNT(*) FROM my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"": 100
|
||||
}
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
<Example>
|
||||
>>> execute_sql("my_project", "my_instance", "my_database",
|
||||
... "SELECT name, rating, description FROM hotels_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": [
|
||||
{
|
||||
"name": "The Hotel",
|
||||
"rating": 4.1,
|
||||
"description": "Modern hotel."
|
||||
},
|
||||
{
|
||||
"name": "Park Inn",
|
||||
"rating": 4.5,
|
||||
"description": "Cozy hotel."
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
</Example>
|
||||
|
||||
Note:
|
||||
This is running with Read-Only Transaction for query that only read data."""
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_execute_sql_query_result(
|
||||
query_result_mode, expected_description
|
||||
):
|
||||
"""Test Spanner execute_sql tool query result in different modes."""
|
||||
tool_name = "execute_sql"
|
||||
tool_settings = SpannerToolSettings(query_result_mode=query_result_mode)
|
||||
tool = await get_tool(tool_name, tool_settings)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == expected_description
|
||||
|
||||
|
||||
@mock.patch.object(query_tool.utils, "execute_sql", spec_set=True)
|
||||
def test_execute_sql(mock_utils_execute_sql):
|
||||
"""Test execute_sql function in query result default mode."""
|
||||
mock_credentials = mock.create_autospec(
|
||||
Credentials, instance=True, spec_set=True
|
||||
)
|
||||
mock_tool_context = mock.create_autospec(
|
||||
ToolContext, instance=True, spec_set=True
|
||||
)
|
||||
mock_utils_execute_sql.return_value = {"status": "SUCCESS", "rows": [[1]]}
|
||||
|
||||
result = query_tool.execute_sql(
|
||||
project_id="test-project",
|
||||
instance_id="test-instance",
|
||||
database_id="test-database",
|
||||
query="SELECT 1",
|
||||
credentials=mock_credentials,
|
||||
settings=settings.SpannerToolSettings(),
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
|
||||
mock_utils_execute_sql.assert_called_once_with(
|
||||
"test-project",
|
||||
"test-instance",
|
||||
"test-database",
|
||||
"SELECT 1",
|
||||
mock_credentials,
|
||||
settings.SpannerToolSettings(),
|
||||
mock_tool_context,
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": [[1]]}
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.tools.spanner.settings import Capabilities
|
||||
from google.adk.tools.spanner.settings import QueryResultMode
|
||||
from google.adk.tools.spanner.settings import SpannerToolSettings
|
||||
from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
|
||||
from pydantic import ValidationError
|
||||
@@ -70,3 +72,26 @@ def test_spanner_vector_store_settings_invalid_vector_length():
|
||||
assert "Invalid vector length in the Spanner vector store settings." in str(
|
||||
excinfo.value
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"settings_args, expected_rows, expected_mode",
|
||||
[
|
||||
({}, 50, QueryResultMode.DEFAULT),
|
||||
(
|
||||
{
|
||||
"capabilities": [Capabilities.DATA_READ],
|
||||
"max_executed_query_result_rows": 100,
|
||||
"query_result_mode": QueryResultMode.DICT_LIST,
|
||||
},
|
||||
100,
|
||||
QueryResultMode.DICT_LIST,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_spanner_tool_settings(settings_args, expected_rows, expected_mode):
|
||||
"""Test SpannerToolSettings with different values."""
|
||||
settings = SpannerToolSettings(**settings_args)
|
||||
assert settings.capabilities == [Capabilities.DATA_READ]
|
||||
assert settings.max_executed_query_result_rows == expected_rows
|
||||
assert settings.query_result_mode == expected_mode
|
||||
|
||||
Reference in New Issue
Block a user