You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: add SpannerVectorStore for orchestrating and providing utility functions for a Spanner vector store
PiperOrigin-RevId: 854392465
This commit is contained in:
committed by
Copybara-Service
parent
8fb2be216f
commit
59eda98eae
@@ -41,6 +41,7 @@ class FeatureName(str, Enum):
|
||||
PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
|
||||
SPANNER_TOOLSET = "SPANNER_TOOLSET"
|
||||
SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS"
|
||||
SPANNER_VECTOR_STORE = "SPANNER_VECTOR_STORE"
|
||||
TOOL_CONFIG = "TOOL_CONFIG"
|
||||
TOOL_CONFIRMATION = "TOOL_CONFIRMATION"
|
||||
|
||||
@@ -120,6 +121,9 @@ _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
|
||||
FeatureName.SPANNER_TOOL_SETTINGS: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
FeatureName.SPANNER_VECTOR_STORE: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
FeatureName.TOOL_CONFIG: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
|
||||
@@ -55,6 +55,74 @@ class QueryResultMode(Enum):
|
||||
"""
|
||||
|
||||
|
||||
class TableColumn(BaseModel):
|
||||
"""Represents a column configuration, used in the CREATE TABLE DDL statement when setting up a new vector store table."""
|
||||
|
||||
name: str
|
||||
"""Required. The name of the column."""
|
||||
|
||||
type: str
|
||||
"""Required. The type of the column.
|
||||
|
||||
For example,
|
||||
|
||||
- GoogleSQL: 'STRING(MAX)', 'INT64', 'FLOAT64', 'BOOL', etc.
|
||||
- PostgreSQL: 'text', 'int8', 'float8', 'boolean', etc.
|
||||
"""
|
||||
|
||||
is_nullable: bool = True
|
||||
"""Optional. Whether the column is nullable. By default, the column is nullable."""
|
||||
|
||||
|
||||
class VectorSearchIndexSettings(BaseModel):
|
||||
"""Settings for the index for use with Approximate Nearest Neighbor (ANN) vector similarity search."""
|
||||
|
||||
index_name: str
|
||||
"""Required. The name of the vector similarity search index."""
|
||||
|
||||
additional_key_columns: Optional[list[str]] = None
|
||||
"""Optional. The list of the additional key column names in the vector similarity search index.
|
||||
|
||||
To further speed up filtering for highly selective filtering columns, organize
|
||||
them as additional keys in the vector index after the embedding column.
|
||||
For example: `category` as additional key column.
|
||||
`CREATE VECTOR INDEX ON documents(embedding, category);`
|
||||
"""
|
||||
|
||||
additional_storing_columns: Optional[list[str]] = None
|
||||
"""Optional. The list of the storing column names in the vector similarity search index.
|
||||
|
||||
This enables filtering while walking the vector index, removing unqualified
|
||||
rows early.
|
||||
For example: `category` as storing column.
|
||||
`CREATE VECTOR INDEX ON documents(embedding) STORING (category);`
|
||||
"""
|
||||
|
||||
tree_depth: int = 2
|
||||
"""Optional. The tree depth (level). This value can be either 2 or 3. Defaults to 2.
|
||||
|
||||
A tree with 2 levels only has leaves (num_leaves) as nodes.
|
||||
If the dataset has more than 100 million rows,
|
||||
then you can use a tree with 3 levels and add branches (num_branches) to
|
||||
further partition the dataset.
|
||||
"""
|
||||
|
||||
num_leaves: int = 1000
|
||||
"""Optional. The number of leaves (i.e. potential partitions) for the vector data. Defaults to 1000.
|
||||
|
||||
You can designate num_leaves for trees with 2 or 3 levels.
|
||||
We recommend that the number of leaves is number_of_rows_in_dataset/1000.
|
||||
"""
|
||||
|
||||
num_branches: Optional[int] = None
|
||||
"""Optional. The number of branches to further partition the vector data.
|
||||
|
||||
You can only designate num_branches for trees with 3 levels.
|
||||
The number of branches must be fewer than the number of leaves.
|
||||
We recommend that the number of leaves is between 1000 and sqrt(number_of_rows_in_dataset).
|
||||
"""
|
||||
|
||||
|
||||
class SpannerVectorStoreSettings(BaseModel):
|
||||
"""Settings for Spanner Vector Store.
|
||||
|
||||
@@ -86,18 +154,19 @@ class SpannerVectorStoreSettings(BaseModel):
|
||||
|
||||
vertex_ai_embedding_model_name: str
|
||||
"""Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.
|
||||
For example, 'text-embedding-005'.
|
||||
|
||||
Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
|
||||
Otherwise, a runtime error might be raised during a query.
|
||||
For example, 'text-embedding-005'.
|
||||
|
||||
Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
|
||||
Otherwise, a runtime error might be raised during a query.
|
||||
"""
|
||||
|
||||
selected_columns: List[str] = []
|
||||
selected_columns: list[str] = []
|
||||
"""Optional. The vector store table columns to return in the vector similarity search result.
|
||||
|
||||
By default, only the `content_column` value and the distance value are returned.
|
||||
If specified, the list of selected columns and the distance value are returned.
|
||||
For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
|
||||
By default, only the `content_column` value and the distance value are returned.
|
||||
If specified, the list of selected columns and the distance value are returned.
|
||||
For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
|
||||
"""
|
||||
|
||||
nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
|
||||
@@ -105,8 +174,8 @@ class SpannerVectorStoreSettings(BaseModel):
|
||||
)
|
||||
"""The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.
|
||||
|
||||
For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
|
||||
For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
|
||||
For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
|
||||
For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
|
||||
"""
|
||||
|
||||
top_k: int = 4
|
||||
@@ -118,16 +187,41 @@ class SpannerVectorStoreSettings(BaseModel):
|
||||
num_leaves_to_search: Optional[int] = None
|
||||
"""Optional. This option specifies how many leaf nodes of the index are searched.
|
||||
|
||||
Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
|
||||
For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
|
||||
Note: This option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
|
||||
For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
|
||||
"""
|
||||
|
||||
additional_filter: Optional[str] = None
|
||||
"""Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query."""
|
||||
|
||||
vector_search_index_settings: Optional[VectorSearchIndexSettings] = None
|
||||
"""Optional. Settings for the index for use with Approximate Nearest Neighbor (ANN) in the vector store.
|
||||
|
||||
Note: This option is only required when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
|
||||
For more details, see https://docs.cloud.google.com/spanner/docs/vector-indexes
|
||||
"""
|
||||
|
||||
additional_columns_to_setup: Optional[list[TableColumn]] = None
|
||||
"""Optional. A list of supplemental columns to be created when initializing a new vector store table or inserting content rows.
|
||||
|
||||
Note: This configuration is only utilized during the initial table setup
|
||||
or when inserting content rows.
|
||||
"""
|
||||
|
||||
primary_key_columns: Optional[list[str]] = None
|
||||
"""Optional. Specifies the column names to be used as the primary key for a new vector store table.
|
||||
|
||||
If provided, every column name listed here must be defined within
|
||||
`additional_columns_to_setup`. If this field is omitted (set to `None`),
|
||||
defaults to a single primary key column named `id` which automatically
|
||||
generates UUIDs for each entry.
|
||||
|
||||
Note: This field is only used during the creation phase of a new vector store.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def __post_init__(self):
|
||||
"""Validate the embedding settings."""
|
||||
"""Validate the vector store settings."""
|
||||
if not self.vector_length or self.vector_length <= 0:
|
||||
raise ValueError(
|
||||
"Invalid vector length in the Spanner vector store settings."
|
||||
@@ -136,6 +230,17 @@ class SpannerVectorStoreSettings(BaseModel):
|
||||
if not self.selected_columns:
|
||||
self.selected_columns = [self.content_column]
|
||||
|
||||
if self.primary_key_columns:
|
||||
cols = {self.content_column, self.embedding_column}
|
||||
if self.additional_columns_to_setup:
|
||||
cols.update({c.name for c in self.additional_columns_to_setup})
|
||||
|
||||
for pk in self.primary_key_columns:
|
||||
if pk not in cols:
|
||||
raise ValueError(
|
||||
f"Primary key column '{pk}' not found in column definitions."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,384 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.spanner import utils as spanner_utils
|
||||
from google.adk.tools.spanner.settings import SpannerToolSettings
|
||||
from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
|
||||
from google.adk.tools.spanner.settings import TableColumn
|
||||
from google.adk.tools.spanner.settings import VectorSearchIndexSettings
|
||||
from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
|
||||
from google.cloud.spanner_v1 import batch as spanner_batch
|
||||
from google.cloud.spanner_v1 import client as spanner_client_v1
|
||||
from google.cloud.spanner_v1 import database as spanner_database
|
||||
from google.cloud.spanner_v1 import instance as spanner_instance
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store_settings():
|
||||
"""Fixture for SpannerVectorStoreSettings."""
|
||||
return SpannerVectorStoreSettings(
|
||||
project_id="test-project",
|
||||
instance_id="test-instance",
|
||||
database_id="test-database",
|
||||
table_name="test_vector_store",
|
||||
content_column="content",
|
||||
embedding_column="embedding",
|
||||
vector_length=768,
|
||||
vertex_ai_embedding_model_name="textembedding",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def spanner_tool_settings(vector_store_settings):
|
||||
"""Fixture for SpannerToolSettings."""
|
||||
return SpannerToolSettings(vector_store_settings=vector_store_settings)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_spanner_database():
|
||||
"""Fixture for a mocked spanner database."""
|
||||
mock_database = mock.create_autospec(spanner_database.Database, instance=True)
|
||||
mock_database.exists.return_value = True
|
||||
mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
|
||||
return mock_database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_spanner_instance(mock_spanner_database):
|
||||
"""Fixture for a mocked spanner instance."""
|
||||
mock_instance = mock.create_autospec(spanner_instance.Instance, instance=True)
|
||||
mock_instance.exists.return_value = True
|
||||
mock_instance.database.return_value = mock_spanner_database
|
||||
return mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_spanner_client(mock_spanner_instance):
|
||||
"""Fixture for a mocked spanner client."""
|
||||
mock_client = mock.create_autospec(spanner_client_v1.Client, instance=True)
|
||||
mock_client.instance.return_value = mock_spanner_instance
|
||||
mock_client._client_info = mock.Mock(user_agent="test-agent")
|
||||
return mock_client
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
|
||||
def test_add_contents_successful(
|
||||
mock_embed_contents,
|
||||
spanner_tool_settings,
|
||||
mock_spanner_client,
|
||||
mock_spanner_database,
|
||||
mocker,
|
||||
):
|
||||
"""Test that add_contents successfully adds content."""
|
||||
mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]]
|
||||
mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True)
|
||||
mock_batch.__enter__.return_value = mock_batch
|
||||
mock_spanner_database.batch.return_value = mock_batch
|
||||
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store._database = mock_spanner_database
|
||||
contents = ["content1", "content2"]
|
||||
vector_store.add_contents(contents=contents)
|
||||
|
||||
mock_spanner_database.reload.assert_called_once()
|
||||
mock_spanner_database.batch.assert_called_once()
|
||||
mock_batch.insert_or_update.assert_called_once_with(
|
||||
table="test_vector_store",
|
||||
columns=["content", "embedding"],
|
||||
values=[
|
||||
["content1", [1.0, 2.0]],
|
||||
["content2", [3.0, 4.0]],
|
||||
],
|
||||
)
|
||||
mock_embed_contents.assert_called_once_with(
|
||||
"textembedding", contents, 768, mock.ANY
|
||||
)
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
|
||||
def test_add_contents_with_metadata(
|
||||
mock_embed_contents,
|
||||
spanner_tool_settings,
|
||||
mock_spanner_client,
|
||||
mock_spanner_database,
|
||||
mocker,
|
||||
):
|
||||
"""Test that add_contents successfully adds content with metadata."""
|
||||
mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]]
|
||||
mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True)
|
||||
mock_batch.__enter__.return_value = mock_batch
|
||||
mock_spanner_database.batch.return_value = mock_batch
|
||||
spanner_tool_settings.vector_store_settings.additional_columns_to_setup = [
|
||||
TableColumn(name="metadata", type="JSON")
|
||||
]
|
||||
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store._database = mock_spanner_database
|
||||
contents = ["content1", "content2"]
|
||||
additional_columns_values = [
|
||||
{"metadata": {"meta1": "val1"}},
|
||||
{"metadata": {"meta2": "val2"}},
|
||||
]
|
||||
vector_store.add_contents(
|
||||
contents=contents,
|
||||
additional_columns_values=additional_columns_values,
|
||||
)
|
||||
|
||||
mock_spanner_database.batch.assert_called_once()
|
||||
mock_batch.insert_or_update.assert_called_once_with(
|
||||
table="test_vector_store",
|
||||
columns=["content", "embedding", "metadata"],
|
||||
values=[
|
||||
["content1", [1.0, 2.0], {"meta1": "val1"}],
|
||||
["content2", [3.0, 4.0], {"meta2": "val2"}],
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_add_contents_empty_contents(
|
||||
spanner_tool_settings, mock_spanner_client, mock_spanner_database
|
||||
):
|
||||
"""Test that add_contents does nothing when contents is empty."""
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store.add_contents(contents=[])
|
||||
mock_spanner_database.batch.assert_not_called()
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
|
||||
def test_add_contents_additional_columns_list_mismatch(
|
||||
mock_embed_contents, spanner_tool_settings, mock_spanner_client
|
||||
):
|
||||
"""Test that add_contents raises an error if additional_columns_values and contents lengths differ."""
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="additional_columns_values contains more items than contents.",
|
||||
):
|
||||
vector_store.add_contents(
|
||||
contents=["content1"],
|
||||
additional_columns_values=[
|
||||
{"col1": "val1"},
|
||||
{"col1": "val2"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@mock.patch.object(spanner_utils, "embed_contents", autospec=True)
|
||||
def test_add_contents_embedding_fails(
|
||||
mock_embed_contents, spanner_tool_settings, mock_spanner_client
|
||||
):
|
||||
"""Test that add_contents fails if embedding fails."""
|
||||
mock_embed_contents.side_effect = RuntimeError("Embedding failed")
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
with pytest.raises(RuntimeError, match="Embedding failed"):
|
||||
vector_store.add_contents(contents=["content1", "content2"])
|
||||
|
||||
|
||||
def test_init_raises_error_if_vector_store_settings_not_set():
|
||||
"""Test that SpannerVectorStore raises an error if vector_store_settings is not set."""
|
||||
settings = SpannerToolSettings()
|
||||
with pytest.raises(
|
||||
ValueError, match="Spanner vector store settings are not set."
|
||||
):
|
||||
spanner_utils.SpannerVectorStore(settings)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dialect, expected_ddl",
|
||||
[
|
||||
(
|
||||
DatabaseDialect.GOOGLE_STANDARD_SQL,
|
||||
(
|
||||
"CREATE TABLE IF NOT EXISTS test_vector_store (\n"
|
||||
" id STRING(36) DEFAULT (GENERATE_UUID()),\n"
|
||||
" content STRING(MAX),\n"
|
||||
" embedding ARRAY<FLOAT32>(vector_length=>768)\n"
|
||||
") PRIMARY KEY(id)"
|
||||
),
|
||||
),
|
||||
(
|
||||
DatabaseDialect.POSTGRESQL,
|
||||
(
|
||||
"CREATE TABLE IF NOT EXISTS test_vector_store (\n"
|
||||
" id varchar(36) DEFAULT spanner.generate_uuid(),\n"
|
||||
" content text,\n"
|
||||
" embedding float4[] VECTOR LENGTH 768,\n"
|
||||
" PRIMARY KEY(id)\n"
|
||||
")"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_create_vector_store_table_ddl(
|
||||
spanner_tool_settings, mock_spanner_client, dialect, expected_ddl
|
||||
):
|
||||
"""Test DDL creation for different SQL dialects."""
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
ddl = vector_store._create_vector_store_table_ddl(dialect)
|
||||
assert ddl == expected_ddl
|
||||
|
||||
|
||||
def test_create_ann_vector_search_index_ddl_raises_error_for_postgresql(
|
||||
spanner_tool_settings, vector_store_settings, mock_spanner_client
|
||||
):
|
||||
"""Test that creating an ANN index raises an error for PostgreSQL."""
|
||||
vector_store_settings.vector_search_index_settings = mock.Mock()
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="ANN is only supported for the Google Standard SQL dialect.",
|
||||
):
|
||||
vector_store._create_ann_vector_search_index_ddl(
|
||||
DatabaseDialect.POSTGRESQL
|
||||
)
|
||||
|
||||
|
||||
def test_create_vector_store(
|
||||
spanner_tool_settings, mock_spanner_client, mock_spanner_database
|
||||
):
|
||||
"""Test the vector store creation process."""
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store.create_vector_store()
|
||||
mock_spanner_database.update_ddl.assert_called_once()
|
||||
ddl_statement = mock_spanner_database.update_ddl.call_args[0][0]
|
||||
assert "CREATE TABLE IF NOT EXISTS test_vector_store" in ddl_statement[0]
|
||||
|
||||
|
||||
def test_create_vector_search_index_no_settings(
|
||||
spanner_tool_settings, mock_spanner_client, mock_spanner_database
|
||||
):
|
||||
"""Test that create_vector_search_index does nothing if settings are not present."""
|
||||
spanner_tool_settings.vector_store_settings.vector_search_index_settings = (
|
||||
None
|
||||
)
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store.create_vector_search_index()
|
||||
mock_spanner_database.update_ddl.assert_not_called()
|
||||
|
||||
|
||||
def test_create_vector_search_index_successful_google_sql(
|
||||
spanner_tool_settings,
|
||||
vector_store_settings,
|
||||
mock_spanner_client,
|
||||
mock_spanner_database,
|
||||
):
|
||||
"""Test that create_vector_search_index successfully creates index for Google SQL."""
|
||||
mock_spanner_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
|
||||
vector_store_settings.vector_search_index_settings = (
|
||||
VectorSearchIndexSettings(
|
||||
index_name="test_vector_index",
|
||||
tree_depth=3,
|
||||
num_branches=10,
|
||||
num_leaves=20,
|
||||
)
|
||||
)
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
vector_store.create_vector_search_index()
|
||||
mock_spanner_database.update_ddl.assert_called_once()
|
||||
ddl_statement = mock_spanner_database.update_ddl.call_args[0][0]
|
||||
expected_ddl = (
|
||||
"CREATE VECTOR INDEX IF NOT EXISTS test_vector_index\n"
|
||||
"\tON test_vector_store(embedding)\n"
|
||||
"\tWHERE embedding IS NOT NULL\n"
|
||||
"\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=10, "
|
||||
"num_leaves=20)"
|
||||
)
|
||||
assert ddl_statement[0] == expected_ddl
|
||||
|
||||
|
||||
def test_create_vector_search_index_fails(
|
||||
spanner_tool_settings,
|
||||
vector_store_settings,
|
||||
mock_spanner_client,
|
||||
mock_spanner_database,
|
||||
):
|
||||
"""Test that create_vector_search_index raises an error if DDL execution fails."""
|
||||
mock_spanner_database.update_ddl.side_effect = RuntimeError("DDL failed")
|
||||
vector_store_settings.vector_search_index_settings = (
|
||||
VectorSearchIndexSettings(index_name="test_vector_index")
|
||||
)
|
||||
with mock.patch.object(
|
||||
spanner_utils.client,
|
||||
"get_spanner_client",
|
||||
autospec=True,
|
||||
return_value=mock_spanner_client,
|
||||
):
|
||||
vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
|
||||
with pytest.raises(RuntimeError, match="DDL failed"):
|
||||
vector_store.create_vector_search_index()
|
||||
Reference in New Issue
Block a user