feat: add SpannerVectorStore for orchestrating and providing utility functions for a Spanner vector store

PiperOrigin-RevId: 854392465
This commit is contained in:
Google Team Member
2026-01-09 16:57:36 -08:00
committed by Copybara-Service
parent 8fb2be216f
commit 59eda98eae
4 changed files with 1121 additions and 14 deletions
@@ -41,6 +41,7 @@ class FeatureName(str, Enum):
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
SPANNER_TOOLSET = "SPANNER_TOOLSET"
SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS"
SPANNER_VECTOR_STORE = "SPANNER_VECTOR_STORE"
TOOL_CONFIG = "TOOL_CONFIG"
TOOL_CONFIRMATION = "TOOL_CONFIRMATION"
@@ -120,6 +121,9 @@ _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
FeatureName.SPANNER_TOOL_SETTINGS: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.SPANNER_VECTOR_STORE: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.TOOL_CONFIG: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
+117 -12
View File
@@ -55,6 +55,74 @@ class QueryResultMode(Enum):
"""
class TableColumn(BaseModel):
  """Represents column configuration, to be used as part of create DDL statement for a new vector store table set up.

  Each instance carries the column name, the column type string and whether
  the column accepts NULL values.
  """

  # Column identifier.
  name: str
  """Required. The name of the column."""

  # Column type string; the examples below show GoogleSQL and PostgreSQL
  # spellings.
  type: str
  """Required. The type of the column.
  For example,
  - GoogleSQL: 'STRING(MAX)', 'INT64', 'FLOAT64', 'BOOL', etc.
  - PostgreSQL: 'text', 'int8', 'float8', 'boolean', etc.
  """

  # Defaults to True, i.e. the column is nullable unless stated otherwise.
  is_nullable: bool = True
  """Optional. Whether the column is nullable. By default, the column is nullable."""
class VectorSearchIndexSettings(BaseModel):
  """Settings for the index for use with Approximate Nearest Neighbor (ANN) vector similarity search.

  Only `index_name` has to be supplied; every other field has a default.
  """

  index_name: str
  """Required. The name of the vector similarity search index."""

  additional_key_columns: Optional[list[str]] = None
  """Optional. The list of the additional key column names in the vector similarity search index.

  To further speed up filtering for highly selective filtering columns, organize
  them as additional keys in the vector index after the embedding column.
  For example: `category` as additional key column.
  `CREATE VECTOR INDEX ON documents(embedding, category);`
  """

  additional_storing_columns: Optional[list[str]] = None
  """Optional. The list of the storing column names in the vector similarity search index.

  This enables filtering while walking the vector index, removing unqualified
  rows early.
  For example: `category` as storing column.
  `CREATE VECTOR INDEX ON documents(embedding) STORING (category);`
  """

  tree_depth: int = 2
  """Optional. The tree depth (level). This value can be either 2 or 3. Defaults to 2.

  A tree with 2 levels only has leaves (num_leaves) as nodes.
  If the dataset has more than 100 million rows,
  then you can use a tree with 3 levels and add branches (num_branches) to
  further partition the dataset.
  """

  num_leaves: int = 1000
  """Optional. The number of leaves (i.e. potential partitions) for the vector data. Defaults to 1000.

  You can designate num_leaves for trees with 2 or 3 levels.
  We recommend that the number of leaves is number_of_rows_in_dataset/1000.
  """

  num_branches: Optional[int] = None
  """Optional. The number of branches to further partition the vector data.

  You can only designate num_branches for trees with 3 levels.
  The number of branches must be fewer than the number of leaves.
  We recommend that the number of branches is between 1000 and sqrt(number_of_rows_in_dataset).
  """
class SpannerVectorStoreSettings(BaseModel):
"""Settings for Spanner Vector Store.
@@ -86,18 +154,19 @@ class SpannerVectorStoreSettings(BaseModel):
vertex_ai_embedding_model_name: str
"""Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.
For example, 'text-embedding-005'.
Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
Otherwise, a runtime error might be raised during a query.
For example, 'text-embedding-005'.
Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
Otherwise, a runtime error might be raised during a query.
"""
selected_columns: List[str] = []
selected_columns: list[str] = []
"""Required. The vector store table columns to return in the vector similarity search result.
By default, only the `content_column` value and the distance value are returned.
If specified, the list of selected columns and the distance value are returned.
For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
By default, only the `content_column` value and the distance value are returned.
If specified, the list of selected columns and the distance value are returned.
For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
"""
nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
@@ -105,8 +174,8 @@ class SpannerVectorStoreSettings(BaseModel):
)
"""The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.
For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
"""
top_k: int = 4
@@ -118,16 +187,41 @@ class SpannerVectorStoreSettings(BaseModel):
num_leaves_to_search: Optional[int] = None
"""Optional. This option specifies how many leaf nodes of the index are searched.
Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
Note: This option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
"""
additional_filter: Optional[str] = None
"""Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query."""
vector_search_index_settings: Optional[VectorSearchIndexSettings] = None
"""Optional. Settings for the index for use with Approximate Nearest Neighbor (ANN) in the vector store.
Note: This option is only required when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
For more details, see https://docs.cloud.google.com/spanner/docs/vector-indexes
"""
additional_columns_to_setup: Optional[list[TableColumn]] = None
"""Optional. A list of supplemental columns to be created when initializing a new vector store table or inserting content rows.
Note: This configuration is only utilized during the initial table setup
or when inserting content rows.
"""
primary_key_columns: Optional[list[str]] = None
"""Optional. Specifies the column names to be used as the primary key for a new vector store table.
If provided, every column name listed here must be defined within
`additional_columns_to_setup`. If this field is omitted (set to `None`),
defaults to a single primary key column named `id` which automatically
generates UUIDs for each entry.
Note: This field is only used during the creation phase of a new vector store.
"""
@model_validator(mode="after")
def __post_init__(self):
"""Validate the embedding settings."""
"""Validate the vector store settings."""
if not self.vector_length or self.vector_length <= 0:
raise ValueError(
"Invalid vector length in the Spanner vector store settings."
@@ -136,6 +230,17 @@ class SpannerVectorStoreSettings(BaseModel):
if not self.selected_columns:
self.selected_columns = [self.content_column]
if self.primary_key_columns:
cols = {self.content_column, self.embedding_column}
if self.additional_columns_to_setup:
cols.update({c.name for c in self.additional_columns_to_setup})
for pk in self.primary_key_columns:
if pk not in cols:
raise ValueError(
f"Primary key column '{pk}' not found in column definitions."
)
return self
File diff suppressed because it is too large Load Diff
+384
View File
@@ -0,0 +1,384 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from unittest import mock
from google.adk.tools.spanner import utils as spanner_utils
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
from google.adk.tools.spanner.settings import TableColumn
from google.adk.tools.spanner.settings import VectorSearchIndexSettings
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
from google.cloud.spanner_v1 import batch as spanner_batch
from google.cloud.spanner_v1 import client as spanner_client_v1
from google.cloud.spanner_v1 import database as spanner_database
from google.cloud.spanner_v1 import instance as spanner_instance
import pytest
@pytest.fixture
def vector_store_settings():
  """Provide the baseline SpannerVectorStoreSettings shared by the tests."""
  settings_kwargs = dict(
      project_id="test-project",
      instance_id="test-instance",
      database_id="test-database",
      table_name="test_vector_store",
      content_column="content",
      embedding_column="embedding",
      vector_length=768,
      vertex_ai_embedding_model_name="textembedding",
  )
  return SpannerVectorStoreSettings(**settings_kwargs)
@pytest.fixture
def spanner_tool_settings(vector_store_settings):
  """Provide SpannerToolSettings wrapping the vector store settings fixture."""
  tool_settings = SpannerToolSettings(
      vector_store_settings=vector_store_settings
  )
  return tool_settings
@pytest.fixture
def mock_spanner_database():
  """Provide an autospec'd Spanner database that exists and uses GoogleSQL."""
  database = mock.create_autospec(spanner_database.Database, instance=True)
  database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
  database.exists.return_value = True
  return database
@pytest.fixture
def mock_spanner_instance(mock_spanner_database):
  """Provide an autospec'd Spanner instance that hands out the mocked database."""
  instance = mock.create_autospec(spanner_instance.Instance, instance=True)
  # Any database lookup on this instance resolves to the shared database mock.
  instance.database.return_value = mock_spanner_database
  instance.exists.return_value = True
  return instance
@pytest.fixture
def mock_spanner_client(mock_spanner_instance):
  """Provide an autospec'd Spanner client that hands out the mocked instance."""
  client = mock.create_autospec(spanner_client_v1.Client, instance=True)
  client._client_info = mock.Mock(user_agent="test-agent")
  # Any instance lookup on this client resolves to the shared instance mock.
  client.instance.return_value = mock_spanner_instance
  return client
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
def test_add_contents_successful(
    mock_embed_contents,
    spanner_tool_settings,
    mock_spanner_client,
    mock_spanner_database,
    mocker,
):
  """Check that add_contents embeds the contents and upserts one row each."""
  mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]]

  # The batch mock returns itself from __enter__ so that calls made on the
  # context-managed object are recorded on the same mock.
  batch = mocker.create_autospec(spanner_batch.Batch, instance=True)
  batch.__enter__.return_value = batch
  mock_spanner_database.batch.return_value = batch

  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store._database = mock_spanner_database
    texts = ["content1", "content2"]
    store.add_contents(contents=texts)

    mock_spanner_database.reload.assert_called_once()
    mock_spanner_database.batch.assert_called_once()
    expected_rows = [
        ["content1", [1.0, 2.0]],
        ["content2", [3.0, 4.0]],
    ]
    batch.insert_or_update.assert_called_once_with(
        table="test_vector_store",
        columns=["content", "embedding"],
        values=expected_rows,
    )
    mock_embed_contents.assert_called_once_with(
        "textembedding", texts, 768, mock.ANY
    )
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
def test_add_contents_with_metadata(
    mock_embed_contents,
    spanner_tool_settings,
    mock_spanner_client,
    mock_spanner_database,
    mocker,
):
  """Check that add_contents writes additional column values with each row."""
  mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]]

  # The batch mock returns itself from __enter__ so that calls made on the
  # context-managed object are recorded on the same mock.
  batch = mocker.create_autospec(spanner_batch.Batch, instance=True)
  batch.__enter__.return_value = batch
  mock_spanner_database.batch.return_value = batch

  # Declare the extra column before the store is constructed.
  store_settings = spanner_tool_settings.vector_store_settings
  store_settings.additional_columns_to_setup = [
      TableColumn(name="metadata", type="JSON")
  ]

  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store._database = mock_spanner_database
    texts = ["content1", "content2"]
    extra_values = [
        {"metadata": {"meta1": "val1"}},
        {"metadata": {"meta2": "val2"}},
    ]
    store.add_contents(
        contents=texts,
        additional_columns_values=extra_values,
    )

    mock_spanner_database.batch.assert_called_once()
    expected_rows = [
        ["content1", [1.0, 2.0], {"meta1": "val1"}],
        ["content2", [3.0, 4.0], {"meta2": "val2"}],
    ]
    batch.insert_or_update.assert_called_once_with(
        table="test_vector_store",
        columns=["content", "embedding", "metadata"],
        values=expected_rows,
    )
def test_add_contents_empty_contents(
    spanner_tool_settings, mock_spanner_client, mock_spanner_database
):
  """Check that add_contents opens no batch when given an empty list."""
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store.add_contents(contents=[])
    mock_spanner_database.batch.assert_not_called()
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
def test_add_contents_additional_columns_list_mismatch(
    mock_embed_contents, spanner_tool_settings, mock_spanner_client
):
  """Check that more additional_columns_values than contents raises ValueError."""
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    # One content item but two value dicts: the lengths disagree.
    too_many_values = [
        {"col1": "val1"},
        {"col1": "val2"},
    ]
    expected_message = (
        "additional_columns_values contains more items than contents."
    )
    with pytest.raises(ValueError, match=expected_message):
      store.add_contents(
          contents=["content1"],
          additional_columns_values=too_many_values,
      )
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
def test_add_contents_embedding_fails(
    mock_embed_contents, spanner_tool_settings, mock_spanner_client
):
  """Check that an embedding error propagates out of add_contents."""
  mock_embed_contents.side_effect = RuntimeError("Embedding failed")
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    texts = ["content1", "content2"]
    with pytest.raises(RuntimeError, match="Embedding failed"):
      store.add_contents(contents=texts)
def test_init_raises_error_if_vector_store_settings_not_set():
  """Check that constructing the store without vector store settings raises."""
  tool_settings = SpannerToolSettings()
  expected_message = "Spanner vector store settings are not set."
  with pytest.raises(ValueError, match=expected_message):
    spanner_utils.SpannerVectorStore(tool_settings)
@pytest.mark.parametrize(
    "dialect, expected_ddl",
    [
        (
            DatabaseDialect.GOOGLE_STANDARD_SQL,
            (
                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
                " id STRING(36) DEFAULT (GENERATE_UUID()),\n"
                " content STRING(MAX),\n"
                " embedding ARRAY<FLOAT32>(vector_length=>768)\n"
                ") PRIMARY KEY(id)"
            ),
        ),
        (
            DatabaseDialect.POSTGRESQL,
            (
                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
                " id varchar(36) DEFAULT spanner.generate_uuid(),\n"
                " content text,\n"
                " embedding float4[] VECTOR LENGTH 768,\n"
                " PRIMARY KEY(id)\n"
                ")"
            ),
        ),
    ],
)
def test_create_vector_store_table_ddl(
    spanner_tool_settings, mock_spanner_client, dialect, expected_ddl
):
  """Check the generated CREATE TABLE DDL for each parametrized dialect."""
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    generated_ddl = store._create_vector_store_table_ddl(dialect)
    assert generated_ddl == expected_ddl
def test_create_ann_vector_search_index_ddl_raises_error_for_postgresql(
    spanner_tool_settings, vector_store_settings, mock_spanner_client
):
  """Check that building ANN index DDL for PostgreSQL raises ValueError."""
  # A plain Mock stands in for the index settings in this test.
  vector_store_settings.vector_search_index_settings = mock.Mock()
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    expected_message = (
        "ANN is only supported for the Google Standard SQL dialect."
    )
    with pytest.raises(ValueError, match=expected_message):
      store._create_ann_vector_search_index_ddl(DatabaseDialect.POSTGRESQL)
def test_create_vector_store(
    spanner_tool_settings, mock_spanner_client, mock_spanner_database
):
  """Check that create_vector_store issues a single CREATE TABLE DDL update."""
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store.create_vector_store()

    update_ddl = mock_spanner_database.update_ddl
    update_ddl.assert_called_once()
    # The first positional argument of update_ddl holds the DDL statements.
    statements = update_ddl.call_args.args[0]
    assert "CREATE TABLE IF NOT EXISTS test_vector_store" in statements[0]
def test_create_vector_search_index_no_settings(
    spanner_tool_settings, mock_spanner_client, mock_spanner_database
):
  """Check that no DDL is issued when index settings are absent."""
  store_settings = spanner_tool_settings.vector_store_settings
  store_settings.vector_search_index_settings = None
  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store.create_vector_search_index()
    mock_spanner_database.update_ddl.assert_not_called()
def test_create_vector_search_index_successful_google_sql(
    spanner_tool_settings,
    vector_store_settings,
    mock_spanner_client,
    mock_spanner_database,
):
  """Check the CREATE VECTOR INDEX DDL issued for a GoogleSQL database."""
  mock_spanner_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
  index_settings = VectorSearchIndexSettings(
      index_name="test_vector_index",
      tree_depth=3,
      num_branches=10,
      num_leaves=20,
  )
  vector_store_settings.vector_search_index_settings = index_settings

  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    store.create_vector_search_index()

    update_ddl = mock_spanner_database.update_ddl
    update_ddl.assert_called_once()
    # The first positional argument of update_ddl holds the DDL statements.
    statements = update_ddl.call_args.args[0]
    expected_ddl = (
        "CREATE VECTOR INDEX IF NOT EXISTS test_vector_index\n"
        "\tON test_vector_store(embedding)\n"
        "\tWHERE embedding IS NOT NULL\n"
        "\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=10, "
        "num_leaves=20)"
    )
    assert statements[0] == expected_ddl
def test_create_vector_search_index_fails(
    spanner_tool_settings,
    vector_store_settings,
    mock_spanner_client,
    mock_spanner_database,
):
  """Check that a DDL execution error propagates from create_vector_search_index."""
  index_settings = VectorSearchIndexSettings(index_name="test_vector_index")
  vector_store_settings.vector_search_index_settings = index_settings
  mock_spanner_database.update_ddl.side_effect = RuntimeError("DDL failed")

  client_patch = mock.patch.object(
      spanner_utils.client,
      "get_spanner_client",
      autospec=True,
      return_value=mock_spanner_client,
  )
  with client_patch:
    store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
    with pytest.raises(RuntimeError, match="DDL failed"):
      store.create_vector_search_index()