feat: BigQuery ADK support for search catalog tool

Merge https://github.com/google/adk-python/pull/4171

**Problem:**
The BigQuery ADK tools currently lack the ability to search for and discover BigQuery assets using the Dataplex Catalog. Users cannot leverage Dataplex's search capabilities within the ADK to find relevant data assets before querying them.

**Solution:**
This PR integrates a new `search_catalog` tool into the BigQuery ADK toolset. The tool uses the Dataplex Catalog client library to call the Dataplex API, allowing users to search the catalog for BigQuery assets.

**Unit Tests:**

- [x] I have added or updated unit tests for my change.
- [x] All unit tests pass locally.

Screenshots of the manual ADK web UI tests are available at https://docs.google.com/document/d/1c_lMW7NYGKuLAvPFmSkLehbqySeNyXQIhzQlvo3ixmQ/edit?usp=sharing

### Checklist

- [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document.
- [x] I have performed a self-review of my own code.
- [x] I have commented my code, particularly in hard-to-understand areas.
- [x] I have added tests that prove my fix is effective or that my feature works.
- [x] New and existing unit tests pass locally with my changes.
- [x] I have manually tested my changes end-to-end.
- [x] Any dependent changes have been merged and published in downstream modules.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4171 from sahaajaaa:sahaajaaa-bq-adk 3dbbaa4f909cb25259e8e7d73a00a58fbe9c2f09
PiperOrigin-RevId: 872951141
This commit is contained in:
Sahaja Reddy Pabbathi Reddy
2026-02-20 09:55:04 -08:00
committed by Copybara-Service
parent a39ca946d6
commit bef3f117b4
10 changed files with 768 additions and 13 deletions
+4
View File
@@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
`ARIMA_PLUS` model and then querying it with
`ML.DETECT_ANOMALIES` to detect time series data anomalies.
11. `search_catalog`
Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets.
## How to use
Set up environment variables in your `.env` file for using
@@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf.
* which tables exist in the ml_datasets dataset?
* show more details about the penguins table
* compute penguins population per island.
* are there any tables related to animals in project <your_project_id>?
+1
View File
@@ -37,6 +37,7 @@ dependencies = [
"google-cloud-bigquery-storage>=2.0.0",
"google-cloud-bigquery>=2.2.0",
"google-cloud-bigtable>=2.32.0", # For Bigtable database
"google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
"google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
@@ -19,6 +19,10 @@ from ...features import FeatureName
from .._google_credentials import BaseGoogleCredentialsConfig
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
BIGQUERY_SCOPES = [
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/dataplex",
]
BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"]
@@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig):
super().__post_init__()
if not self.scopes:
self.scopes = BIGQUERY_DEFAULT_SCOPE
self.scopes = BIGQUERY_SCOPES
# Set the token cache key
self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY
return self
@@ -24,6 +24,7 @@ from typing_extensions import override
from . import data_insights_tool
from . import metadata_tool
from . import query_tool
from . import search_tool
from ...features import experimental
from ...features import FeatureName
from ...tools.base_tool import BaseTool
@@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset):
query_tool.analyze_contribution,
query_tool.detect_anomalies,
data_insights_tool.ask_data_insights,
search_tool.search_catalog,
]
]
+39 -6
View File
@@ -14,19 +14,22 @@
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
import google.api_core.client_info
from google.api_core.gapic_v1 import client_info as gapic_client_info
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.cloud import dataplex_v1
from ... import version
USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}"
from typing import List
from typing import Union
USER_AGENT_BASE = f"google-adk/{version.__version__}"
BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}"
DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}"
USER_AGENT = BQ_USER_AGENT
def get_bigquery_client(
@@ -48,7 +51,7 @@ def get_bigquery_client(
A BigQuery client.
"""
user_agents = [USER_AGENT]
user_agents = [BQ_USER_AGENT]
if user_agent:
if isinstance(user_agent, str):
user_agents.append(user_agent)
@@ -67,3 +70,33 @@ def get_bigquery_client(
)
return bigquery_client
def get_dataplex_catalog_client(
    *,
    credentials: Credentials,
    user_agent: Optional[Union[str, List[str]]] = None,
) -> dataplex_v1.CatalogServiceClient:
  """Build a Dataplex CatalogServiceClient tagged with the ADK user agent.

  Args:
    credentials: The credentials the client uses for its requests.
    user_agent: Optional extra user agent token, or list of tokens, appended
      after the default Dataplex tool user agent. Falsy list entries (such as
      None or an empty string) are skipped.

  Returns:
    A Dataplex CatalogServiceClient whose client info carries the
    space-joined user agent string.
  """
  if not user_agent:
    extra_agents = []
  elif isinstance(user_agent, str):
    extra_agents = [user_agent]
  else:
    # Lists may contain None (e.g. an unset application name); drop those.
    extra_agents = [token for token in user_agent if token]

  combined_user_agent = " ".join([DP_USER_AGENT, *extra_agents])
  return dataplex_v1.CatalogServiceClient(
      credentials=credentials,
      client_info=gapic_client_info.ClientInfo(user_agent=combined_user_agent),
  )
@@ -0,0 +1,179 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Iterable
import logging
from typing import Any

from google.api_core import exceptions as api_exceptions
from google.auth.credentials import Credentials
from google.cloud import dataplex_v1

from . import client
from .config import BigQueryToolConfig
def _construct_search_query_helper(
predicate: str, operator: str, items: list[str]
) -> str:
"""Constructs a search query part for a specific predicate and items."""
if not items:
return ""
clauses = [f'{predicate}{operator}"{item}"' for item in items]
return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0]
def search_catalog(
    prompt: str,
    project_id: str,
    *,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    location: str | None = None,
    page_size: int = 10,
    project_ids_filter: list[str] | None = None,
    dataset_ids_filter: list[str] | None = None,
    types_filter: list[str] | None = None,
) -> dict[str, Any]:
  """Searches for BigQuery assets within Dataplex.

  Args:
      prompt: The base search query (natural language or keywords).
      project_id: The Google Cloud project ID to scope the search.
      credentials: Credentials for the request.
      settings: BigQuery tool settings.
      location: The Dataplex location to use.
      page_size: Maximum number of results.
      project_ids_filter: Specific project IDs to include in the search results.
        If None, defaults to the scoping project_id.
      dataset_ids_filter: BigQuery dataset IDs to filter by.
      types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE,
        BigQueryEntryType.DATASET).

  Returns:
      Search results or error. The "results" list contains items with:
          - name: The Dataplex Entry name (e.g.,
            "projects/p/locations/l/entryGroups/g/entries/e").
          - linked_resource: The underlying BigQuery resource name (e.g.,
            "//bigquery.googleapis.com/projects/p/datasets/d/tables/t").
          - display_name, entry_type, description, location, update_time.

  Examples:
      Search for tables related to customer data:

          >>> search_catalog(
          ...     prompt="Search for tables related to customer data",
          ...     project_id="my-project",
          ...     credentials=creds,
          ...     settings=settings
          ... )
          {
            "status": "SUCCESS",
            "results": [
              {
                "name":
                "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id",
                "display_name": "customer_table",
                "entry_type":
                "projects/p/locations/l/entryTypes/bigquery-table",
                "linked_resource":
                "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table",
                "description": "Table containing customer details.",
                "location": "us",
                "update_time": "2024-01-01 12:00:00+00:00"
              }
            ]
          }
  """
  try:
    if not project_id:
      return {
          "status": "ERROR",
          "error_details": "project_id must be provided.",
      }

    with client.get_dataplex_catalog_client(
        credentials=credentials,
        user_agent=[settings.application_name, "search_catalog"],
    ) as dataplex_client:
      # The free-text prompt is parenthesized so the filters AND-ed onto it
      # below apply to the prompt as a whole.
      clauses = [f"({prompt})"] if prompt else []

      # Project filter: explicit list if given, else the scoping project.
      target_projects = project_ids_filter or [project_id]
      if target_projects:
        clauses.append(
            _construct_search_query_helper("projectid", "=", target_projects)
        )

      # Dataset filter: one linked_resource pattern per (project, dataset).
      if dataset_ids_filter:
        dataset_clauses = [
            f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"'
            for pid in target_projects
            for did in dataset_ids_filter
        ]
        if dataset_clauses:
          clauses.append(f"({' OR '.join(dataset_clauses)})")

      # Entry type filter.
      if types_filter:
        clauses.append(
            _construct_search_query_helper("type", "=", types_filter)
        )

      # Always scope to BigQuery system.
      clauses.append("system=BIGQUERY")
      full_query = " AND ".join(part for part in clauses if part)

      # Precedence: explicit argument, then tool settings, then "global".
      search_location = location or settings.location or "global"
      request = dataplex_v1.SearchEntriesRequest(
          name=f"projects/{project_id}/locations/{search_location}",
          query=full_query,
          page_size=page_size,
          semantic_search=True,
      )
      response = dataplex_client.search_entries(request=request)

      results = []
      for hit in response.results:
        entry = hit.dataplex_entry
        source = entry.entry_source
        row = {
            "name": entry.name,
            "display_name": source.display_name or "",
            "entry_type": entry.entry_type,
            "update_time": str(entry.update_time),
            "linked_resource": source.resource or "",
            "description": source.description or "",
            "location": source.location or "",
        }
        results.append(row)
      return {"status": "SUCCESS", "results": results}

  except api_exceptions.GoogleAPICallError as e:
    logging.exception("search_catalog tool: API call failed")
    return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"}
  except Exception as e:
    logging.exception("search_catalog tool: Unexpected error")
    return {"status": "ERROR", "error_details": repr(e)}
@@ -18,9 +18,13 @@ import os
from unittest import mock
import google.adk
from google.adk.tools.bigquery.client import DP_USER_AGENT
from google.adk.tools.bigquery.client import get_bigquery_client
from google.adk.tools.bigquery.client import get_dataplex_catalog_client
from google.api_core.gapic_v1 import client_info as gapic_client_info
import google.auth
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import dataplex_v1
from google.cloud.bigquery import client as bigquery_client
from google.oauth2.credentials import Credentials
@@ -201,3 +205,74 @@ def test_bigquery_client_location_custom():
# Verify that the client has the desired project set
assert client.project == "test-gcp-project"
assert client.location == "us-central1"
# Tests for Dataplex Catalog Client
# ------------------------------------------------------------------------------
# Mock the CatalogServiceClient class directly
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_default(mock_catalog_service_client):
  """Without extra user agents, only the default Dataplex agent is sent."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)

  returned_client = get_dataplex_catalog_client(credentials=fake_credentials)

  mock_catalog_service_client.assert_called_once()
  ctor_kwargs = mock_catalog_service_client.call_args.kwargs
  assert ctor_kwargs["credentials"] == fake_credentials

  sent_client_info = ctor_kwargs["client_info"]
  assert isinstance(sent_client_info, gapic_client_info.ClientInfo)
  assert sent_client_info.user_agent == DP_USER_AGENT

  # The factory must hand back exactly what the (patched) constructor built.
  assert returned_client == mock_catalog_service_client.return_value
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client):
  """A single custom user agent string is appended after the default one."""
  extra_agent = "catalog_ua/1.0"

  get_dataplex_catalog_client(
      credentials=mock.create_autospec(Credentials, instance=True),
      user_agent=extra_agent,
  )

  mock_catalog_service_client.assert_called_once()
  sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert sent_client_info.user_agent == f"{DP_USER_AGENT} {extra_agent}"
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client):
  """Every entry of a custom user agent list is appended, in order."""
  extra_agents = ["catalog_ua", "catalog_ua_2.0"]

  get_dataplex_catalog_client(
      credentials=mock.create_autospec(Credentials, instance=True),
      user_agent=extra_agents,
  )

  mock_catalog_service_client.assert_called_once()
  sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert sent_client_info.user_agent == " ".join([DP_USER_AGENT, *extra_agents])
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list_with_none(
    mock_catalog_service_client,
):
  """None entries in a custom user agent list are dropped, not stringified."""
  extra_agents = ["catalog_ua", None, "catalog_ua_2.0"]

  get_dataplex_catalog_client(
      credentials=mock.create_autospec(Credentials, instance=True),
      user_agent=extra_agents,
  )

  mock_catalog_service_client.assert_called_once()
  sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert (
      sent_client_info.user_agent
      == f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0"
  )
@@ -44,9 +44,11 @@ class TestBigQueryCredentials:
# Verify that the credentials are properly stored and attributes are extracted
assert config.credentials == auth_creds
assert config.client_id is None
assert config.client_secret is None
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
assert config.scopes == [
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/dataplex",
]
def test_valid_credentials_object_oauth2_credentials(self):
"""Test that providing valid Credentials object works correctly with
@@ -86,7 +88,10 @@ class TestBigQueryCredentials:
assert config.credentials is None
assert config.client_id == "test_client_id"
assert config.client_secret == "test_client_secret"
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
assert config.scopes == [
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/dataplex",
]
def test_valid_client_id_secret_pair_w_scope(self):
"""Test that providing client ID and secret with explicit scopes works.
@@ -128,7 +133,10 @@ class TestBigQueryCredentials:
assert config.credentials is None
assert config.client_id == "test_client_id"
assert config.client_secret == "test_client_secret"
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
assert config.scopes == [
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/dataplex",
]
def test_missing_client_secret_raises_error(self):
"""Test that missing client secret raises appropriate validation error.
@@ -0,0 +1,448 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import sys
from typing import Any
import unittest
from unittest import mock
from absl.testing import parameterized
# Mock google.genai and pydantic if not available, before importing google.adk modules
try:
import google.genai
except ImportError:
m = mock.MagicMock()
m.__path__ = []
sys.modules["google.genai"] = m
sys.modules["google.genai.types"] = mock.MagicMock()
sys.modules["google.genai.errors"] = mock.MagicMock()
try:
import pydantic
except ImportError:
m_pydantic = mock.MagicMock()
class MockBaseModel:
pass
m_pydantic.BaseModel = MockBaseModel
sys.modules["pydantic"] = m_pydantic
try:
import fastapi
import fastapi.openapi.models
except ImportError:
m_fastapi = mock.MagicMock()
m_fastapi.openapi.models = mock.MagicMock()
sys.modules["fastapi"] = m_fastapi
sys.modules["fastapi.openapi"] = mock.MagicMock()
sys.modules["fastapi.openapi.models"] = mock.MagicMock()
from google.adk.tools.bigquery import search_tool
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.api_core import exceptions as api_exceptions
from google.auth.credentials import Credentials
from google.cloud import dataplex_v1
def _mock_creds():
  """Returns an autospec'd Credentials instance for use as a test double."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)
  return fake_credentials
def _mock_settings(app_name: str | None = "test-app"):
  """Returns a BigQueryToolConfig whose application_name is `app_name`."""
  tool_config = BigQueryToolConfig(application_name=app_name)
  return tool_config
def _mock_search_entries_response(results: list[dict[str, Any]]):
  """Builds a mocked SearchEntriesResponse from plain result dictionaries.

  Args:
    results: One dict per search hit. Recognized keys are "name",
      "entry_type", "update_time", "display_name", "linked_resource",
      "description" and "location". Missing keys become None, except
      "update_time" which defaults to "2026-01-14T05:00:00Z".

  Returns:
    A MagicMock spec'd on SearchEntriesResponse whose `results` attribute is
    the list of mocked search results.
  """

  def _fake_hit(fields: dict[str, Any]):
    # Nested proto fields are not visible in dir() of the proto classes, so
    # entry_source and dataplex_entry are attached to the autospec'd mocks
    # by hand.
    source = mock.create_autospec(dataplex_v1.EntrySource, instance=True)
    source.display_name = fields.get("display_name")
    source.resource = fields.get("linked_resource")
    source.description = fields.get("description")
    source.location = fields.get("location")

    entry = mock.create_autospec(dataplex_v1.Entry, instance=True)
    entry.name = fields.get("name")
    entry.entry_type = fields.get("entry_type")
    entry.update_time = fields.get("update_time", "2026-01-14T05:00:00Z")
    entry.entry_source = source

    hit = mock.create_autospec(dataplex_v1.SearchEntriesResult, instance=True)
    hit.dataplex_entry = entry
    return hit

  response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse)
  response.results = [_fake_hit(fields) for fields in results]
  return response
class TestSearchCatalog(parameterized.TestCase):
  """Tests search_tool.search_catalog against a mocked Dataplex client."""

  def setUp(self):
    super().setUp()
    self.mock_dataplex_client = mock.create_autospec(
        dataplex_v1.CatalogServiceClient, instance=True
    )
    # The tool uses the client as a context manager, so entering it must
    # yield the same mock.
    self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client

    # Replace the client factory so no real Dataplex client is constructed.
    self.mock_get_dataplex_client = self.enter_context(
        mock.patch(
            "google.adk.tools.bigquery.client.get_dataplex_catalog_client",
            autospec=True,
        )
    )
    self.mock_get_dataplex_client.return_value = self.mock_dataplex_client

    # Replace the request class so its constructor arguments can be inspected.
    self.mock_search_request = self.enter_context(
        mock.patch(
            "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True
        )
    )

  def test_search_catalog_success(self):
    """A successful search returns formatted rows and the expected request."""
    credentials = _mock_creds()
    tool_settings = _mock_settings()
    target_project = "test-project"
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([
            {
                "name": "entry1",
                "entry_type": "TABLE",
                "display_name": "Cust Table",
                "linked_resource": (
                    "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1"
                ),
                "description": "Table 1",
                "location": "us",
            },
        ])
    )

    outcome = search_tool.search_catalog(
        prompt="customer data",
        project_id=target_project,
        credentials=credentials,
        settings=tool_settings,
        location="us",
    )

    with self.subTest("Test result content"):
      self.assertEqual(outcome["status"], "SUCCESS")
      self.assertLen(outcome["results"], 1)
      self.assertEqual(outcome["results"][0]["name"], "entry1")
      self.assertEqual(outcome["results"][0]["display_name"], "Cust Table")

    with self.subTest("Test mock calls"):
      self.mock_get_dataplex_client.assert_called_once_with(
          credentials=credentials, user_agent=["test-app", "search_catalog"]
      )
      self.mock_search_request.assert_called_once_with(
          name=f"projects/{target_project}/locations/us",
          query=(
              '(customer data) AND projectid="test-project" AND system=BIGQUERY'
          ),
          page_size=10,
          semantic_search=True,
      )
      self.mock_dataplex_client.search_entries.assert_called_once_with(
          request=self.mock_search_request.return_value
      )

  def test_search_catalog_no_project_id(self):
    """An empty project_id is rejected before any client is created."""
    outcome = search_tool.search_catalog(
        prompt="test",
        project_id="",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )

    self.assertEqual(outcome["status"], "ERROR")
    self.assertIn("project_id must be provided", outcome["error_details"])
    self.mock_get_dataplex_client.assert_not_called()

  def test_search_catalog_api_error(self):
    """A Google API error from the search call is reported as an ERROR."""
    self.mock_dataplex_client.search_entries.side_effect = (
        api_exceptions.BadRequest("Invalid query")
    )

    outcome = search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )

    self.assertEqual(outcome["status"], "ERROR")
    self.assertIn(
        "Dataplex API Error: 400 Invalid query", outcome["error_details"]
    )

  def test_search_catalog_other_exception(self):
    """A non-API exception while creating the client yields an ERROR."""
    self.mock_get_dataplex_client.side_effect = Exception(
        "Something went wrong"
    )

    outcome = search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )

    self.assertEqual(outcome["status"], "ERROR")
    self.assertIn("Something went wrong", outcome["error_details"])

  @parameterized.named_parameters(
      ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'),
      (
          "multi_project_filter",
          "p",
          ["p1", "p2"],
          None,
          None,
          '(projectid="p1" OR projectid="p2")',
      ),
      ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'),
      (
          "multi_type_filter",
          "p",
          None,
          None,
          ["TABLE", "DATASET"],
          '(type="TABLE" OR type="DATASET")',
      ),
      (
          "project_and_dataset_filters",
          "inventory",
          ["proj1", "proj2"],
          ["dsetA"],
          None,
          (
              '(projectid="proj1" OR projectid="proj2") AND'
              ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"'
              ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")'
          ),
      ),
  )
  def test_search_catalog_query_construction(
      self, prompt, project_ids, dataset_ids, types, expected_query_part
  ):
    """Each filter combination contributes the expected query fragment."""
    search_tool.search_catalog(
        prompt=prompt,
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
        project_ids_filter=project_ids,
        dataset_ids_filter=dataset_ids,
        types_filter=types,
    )

    self.mock_search_request.assert_called_once()
    sent_query = self.mock_search_request.call_args.kwargs["query"]

    if prompt:
      assert f"({prompt})" in sent_query
    assert "system=BIGQUERY" in sent_query
    assert expected_query_part in sent_query

  def test_search_catalog_no_app_name(self):
    """A None application_name is forwarded as-is in the user agent list."""
    credentials = _mock_creds()

    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=credentials,
        settings=_mock_settings(app_name=None),
        location="us",
    )

    self.mock_get_dataplex_client.assert_called_once_with(
        credentials=credentials, user_agent=[None, "search_catalog"]
    )

  def test_search_catalog_multi_project_filter_semantic(self):
    """Semantic search combines a multi-project filter with a type filter."""
    question = "What datasets store user profiles?"
    scoping_project = "main-project"
    search_location = "global"
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    search_tool.search_catalog(
        prompt=question,
        project_id=scoping_project,
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location=search_location,
        project_ids_filter=["user-data-proj", "shared-infra-proj"],
        types_filter=["DATASET"],
    )

    self.mock_search_request.assert_called_once_with(
        name=f"projects/{scoping_project}/locations/{search_location}",
        query=(
            f"({question}) AND "
            '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND '
            'type="DATASET" AND system=BIGQUERY'
        ),
        page_size=10,
        semantic_search=True,
    )
    self.mock_dataplex_client.search_entries.assert_called_once()

  def test_search_catalog_natural_language_semantic(self):
    """A natural-language prompt is sent semantically and rows are formatted."""
    question = "Find tables about football matches"
    scoping_project = "sports-analytics"
    search_location = "europe-west1"
    table_entry_type = (
        "projects/655216118709/locations/global/entryTypes/bigquery-table"
    )

    # What the API would return for this semantic query.
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([
            {
                "name": (
                    "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1"
                ),
                "display_name": "uk_football_premiership",
                "entry_type": table_entry_type,
                "linked_resource": (
                    "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership"
                ),
                "description": "Stats for UK Premier League matches.",
                "location": "europe-west1",
            },
            {
                "name": (
                    "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2"
                ),
                "display_name": "serie_a_matches",
                "entry_type": table_entry_type,
                "linked_resource": (
                    "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a"
                ),
                "description": "Italian Serie A football results.",
                "location": "europe-west1",
            },
        ])
    )

    outcome = search_tool.search_catalog(
        prompt=question,
        project_id=scoping_project,
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location=search_location,
    )

    with self.subTest("Query Construction"):
      self.mock_search_request.assert_called_once_with(
          name=f"projects/{scoping_project}/locations/{search_location}",
          query=(
              f'({question}) AND projectid="{scoping_project}" AND'
              " system=BIGQUERY"
          ),
          page_size=10,
          semantic_search=True,
      )
      self.mock_dataplex_client.search_entries.assert_called_once()

    with self.subTest("Response Processing"):
      self.assertEqual(outcome["status"], "SUCCESS")
      self.assertLen(outcome["results"], 2)
      self.assertEqual(
          outcome["results"][0]["display_name"], "uk_football_premiership"
      )
      self.assertEqual(outcome["results"][1]["display_name"], "serie_a_matches")
      self.assertIn("UK Premier League", outcome["results"][0]["description"])

  def test_search_catalog_default_location(self):
    """With no location argument or setting, the search uses "global"."""
    # _mock_settings() leaves settings.location unset.
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
    )

    self.mock_search_request.assert_called_once()
    sent_scope = self.mock_search_request.call_args.kwargs["name"]
    self.assertIn("locations/global", sent_scope)

  def test_search_catalog_settings_location(self):
    """With no location argument, settings.location is used for the scope."""
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=BigQueryToolConfig(location="eu"),
    )

    self.mock_search_request.assert_called_once()
    sent_scope = self.mock_search_request.call_args.kwargs["name"]
    self.assertIn("locations/eu", sent_scope)
@@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default():
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 10
assert len(tools) == 11
assert all([isinstance(tool, GoogleTool) for tool in tools])
expected_tool_names = set([
@@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default():
"forecast",
"analyze_contribution",
"detect_anomalies",
"search_catalog",
])
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names