You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: BigQuery ADK support for search catalog tool
Merge https://github.com/google/adk-python/pull/4171 **Problem:** The BigQuery ADK tools currently lack the ability to search for and discover BigQuery assets using the Dataplex Catalog. Users cannot leverage Dataplex's search capabilities within the ADK to find relevant data assets before querying them. **Solution:** This PR integrates a new search_catalog_tool into the BigQuery ADK. This tool utilizes the dataplex catalog client library to interact with the Dataplex API, allowing users to search the catalog. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. Added the screenshots of the manual adk web UI tests - https://docs.google.com/document/d/1c_lMW7NYGKuLAvPFmSkLehbqySeNyXQIhzQlvo3ixmQ/edit?usp=sharing ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4171 from sahaajaaa:sahaajaaa-bq-adk 3dbbaa4f909cb25259e8e7d73a00a58fbe9c2f09 PiperOrigin-RevId: 872951141
This commit is contained in:
committed by
Copybara-Service
parent
a39ca946d6
commit
bef3f117b4
@@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include:
|
||||
`ARIMA_PLUS` model and then querying it with
|
||||
`ML.DETECT_ANOMALIES` to detect time series data anomalies.
|
||||
|
||||
11. `search_catalog`
|
||||
Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets.
|
||||
|
||||
## How to use
|
||||
|
||||
Set up environment variables in your `.env` file for using
|
||||
@@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf.
|
||||
* which tables exist in the ml_datasets dataset?
|
||||
* show more details about the penguins table
|
||||
* compute penguins population per island.
|
||||
* are there any tables related to animals in project <your_project_id>?
|
||||
|
||||
@@ -37,6 +37,7 @@ dependencies = [
|
||||
"google-cloud-bigquery-storage>=2.0.0",
|
||||
"google-cloud-bigquery>=2.2.0",
|
||||
"google-cloud-bigtable>=2.32.0", # For Bigtable database
|
||||
"google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool
|
||||
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
|
||||
"google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool
|
||||
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
|
||||
|
||||
@@ -19,6 +19,10 @@ from ...features import FeatureName
|
||||
from .._google_credentials import BaseGoogleCredentialsConfig
|
||||
|
||||
BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache"
|
||||
BIGQUERY_SCOPES = [
|
||||
"https://www.googleapis.com/auth/bigquery",
|
||||
"https://www.googleapis.com/auth/dataplex",
|
||||
]
|
||||
BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"]
|
||||
|
||||
|
||||
@@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig):
|
||||
super().__post_init__()
|
||||
|
||||
if not self.scopes:
|
||||
self.scopes = BIGQUERY_DEFAULT_SCOPE
|
||||
|
||||
self.scopes = BIGQUERY_SCOPES
|
||||
# Set the token cache key
|
||||
self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY
|
||||
|
||||
return self
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing_extensions import override
|
||||
from . import data_insights_tool
|
||||
from . import metadata_tool
|
||||
from . import query_tool
|
||||
from . import search_tool
|
||||
from ...features import experimental
|
||||
from ...features import FeatureName
|
||||
from ...tools.base_tool import BaseTool
|
||||
@@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset):
|
||||
query_tool.analyze_contribution,
|
||||
query_tool.detect_anomalies,
|
||||
data_insights_tool.ask_data_insights,
|
||||
search_tool.search_catalog,
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import google.api_core.client_info
|
||||
from google.api_core.gapic_v1 import client_info as gapic_client_info
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import dataplex_v1
|
||||
|
||||
from ... import version
|
||||
|
||||
USER_AGENT = f"adk-bigquery-tool google-adk/{version.__version__}"
|
||||
|
||||
|
||||
from typing import List
|
||||
from typing import Union
|
||||
USER_AGENT_BASE = f"google-adk/{version.__version__}"
|
||||
BQ_USER_AGENT = f"adk-bigquery-tool {USER_AGENT_BASE}"
|
||||
DP_USER_AGENT = f"adk-dataplex-tool {USER_AGENT_BASE}"
|
||||
USER_AGENT = BQ_USER_AGENT
|
||||
|
||||
|
||||
def get_bigquery_client(
|
||||
@@ -48,7 +51,7 @@ def get_bigquery_client(
|
||||
A BigQuery client.
|
||||
"""
|
||||
|
||||
user_agents = [USER_AGENT]
|
||||
user_agents = [BQ_USER_AGENT]
|
||||
if user_agent:
|
||||
if isinstance(user_agent, str):
|
||||
user_agents.append(user_agent)
|
||||
@@ -67,3 +70,33 @@ def get_bigquery_client(
|
||||
)
|
||||
|
||||
return bigquery_client
|
||||
|
||||
|
||||
def get_dataplex_catalog_client(
    *,
    credentials: Credentials,
    user_agent: Optional[Union[str, List[str]]] = None,
) -> dataplex_v1.CatalogServiceClient:
  """Build a Dataplex ``CatalogServiceClient`` tagged with the ADK user agent.

  The resulting user-agent string always starts with ``DP_USER_AGENT``; any
  caller-supplied agent(s) are appended after it, separated by single spaces.

  Args:
    credentials: The credentials the client uses for its requests.
    user_agent: Optional extra user agent, either a single string or a list of
      strings. Falsy list entries (``None`` or ``""``) are skipped.

  Returns:
    A ``dataplex_v1.CatalogServiceClient`` constructed with ``credentials`` and
    a ``ClientInfo`` carrying the combined user-agent string.
  """
  if not user_agent:
    extra_agents = []
  elif isinstance(user_agent, str):
    extra_agents = [user_agent]
  else:
    # Callers may pass optional values (e.g. an unset application name), so
    # drop anything falsy instead of letting it reach the join below.
    extra_agents = [agent for agent in user_agent if agent]

  combined_user_agent = " ".join([DP_USER_AGENT, *extra_agents])

  return dataplex_v1.CatalogServiceClient(
      credentials=credentials,
      client_info=gapic_client_info.ClientInfo(user_agent=combined_user_agent),
  )
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from google.api_core import exceptions as api_exceptions
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import dataplex_v1
|
||||
|
||||
from . import client
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
|
||||
def _construct_search_query_helper(
|
||||
predicate: str, operator: str, items: list[str]
|
||||
) -> str:
|
||||
"""Constructs a search query part for a specific predicate and items."""
|
||||
if not items:
|
||||
return ""
|
||||
|
||||
clauses = [f'{predicate}{operator}"{item}"' for item in items]
|
||||
return "(" + " OR ".join(clauses) + ")" if len(items) > 1 else clauses[0]
|
||||
|
||||
|
||||
def search_catalog(
    prompt: str,
    project_id: str,
    *,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    location: str | None = None,
    page_size: int = 10,
    project_ids_filter: list[str] | None = None,
    dataset_ids_filter: list[str] | None = None,
    types_filter: list[str] | None = None,
) -> dict[str, Any]:
  """Searches for BigQuery assets within Dataplex.

  Args:
    prompt: The base search query (natural language or keywords).
    project_id: The Google Cloud project ID to scope the search.
    credentials: Credentials for the request.
    settings: BigQuery tool settings.
    location: The Dataplex location to use.
    page_size: Maximum number of results.
    project_ids_filter: Specific project IDs to include in the search results.
      If None, defaults to the scoping project_id.
    dataset_ids_filter: BigQuery dataset IDs to filter by.
    types_filter: Entry types to filter by (e.g., BigQueryEntryType.TABLE,
      BigQueryEntryType.DATASET).

  Returns:
    Search results or error. The "results" list contains items with:
    - name: The Dataplex Entry name (e.g.,
      "projects/p/locations/l/entryGroups/g/entries/e").
    - linked_resource: The underlying BigQuery resource name (e.g.,
      "//bigquery.googleapis.com/projects/p/datasets/d/tables/t").
    - display_name, entry_type, description, location, update_time.

  Examples:
    Search for tables related to customer data:

      >>> search_catalog(
      ...     prompt="Search for tables related to customer data",
      ...     project_id="my-project",
      ...     credentials=creds,
      ...     settings=settings
      ... )
      {
        "status": "SUCCESS",
        "results": [
          {
            "name":
            "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id",
            "display_name": "customer_table",
            "entry_type":
            "projects/p/locations/l/entryTypes/bigquery-table",
            "linked_resource":
            "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table",
            "description": "Table containing customer details.",
            "location": "us",
            "update_time": "2024-01-01 12:00:00+00:00"
          }
        ]
      }
  """
  # NOTE(review): the docstring above is presumably surfaced to the model as
  # the tool description, so its wording is left untouched -- confirm before
  # rewording it.
  try:
    # Validate before any client is created so that a bad call costs nothing.
    if not project_id:
      return {
          "status": "ERROR",
          "error_details": "project_id must be provided.",
      }

    catalog_client = client.get_dataplex_catalog_client(
        credentials=credentials,
        user_agent=[settings.application_name, "search_catalog"],
    )
    with catalog_client as dataplex_client:
      # With no explicit project filter, restrict results to the scoping
      # project itself; this list is therefore never empty.
      target_projects = project_ids_filter or [project_id]

      clauses = []
      if prompt:
        clauses.append(f"({prompt})")

      clauses.append(
          _construct_search_query_helper("projectid", "=", target_projects)
      )

      if dataset_ids_filter:
        # One linked_resource pattern per (project, dataset) pair.
        dataset_clauses = [
            f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"'
            for pid in target_projects
            for did in dataset_ids_filter
        ]
        if dataset_clauses:
          clauses.append(f"({' OR '.join(dataset_clauses)})")

      if types_filter:
        clauses.append(
            _construct_search_query_helper("type", "=", types_filter)
        )

      # Every query is restricted to entries from the BigQuery system.
      clauses.append("system=BIGQUERY")

      full_query = " AND ".join(clause for clause in clauses if clause)

      # Location precedence: explicit argument, then tool settings, then
      # "global".
      scope_location = location or settings.location or "global"

      request = dataplex_v1.SearchEntriesRequest(
          name=f"projects/{project_id}/locations/{scope_location}",
          query=full_query,
          page_size=page_size,
          semantic_search=True,
      )
      response = dataplex_client.search_entries(request=request)

      results = []
      for hit in response.results:
        entry = hit.dataplex_entry
        source = entry.entry_source
        # Missing string fields are normalised to "" so the output shape is
        # stable for the caller.
        results.append({
            "name": entry.name,
            "display_name": source.display_name or "",
            "entry_type": entry.entry_type,
            "update_time": str(entry.update_time),
            "linked_resource": source.resource or "",
            "description": source.description or "",
            "location": source.location or "",
        })

      return {"status": "SUCCESS", "results": results}

  except api_exceptions.GoogleAPICallError as api_error:
    logging.exception("search_catalog tool: API call failed")
    return {
        "status": "ERROR",
        "error_details": f"Dataplex API Error: {api_error}",
    }
  except Exception as unexpected_error:
    # Tool boundary: report the failure to the caller instead of raising.
    logging.exception("search_catalog tool: Unexpected error")
    return {"status": "ERROR", "error_details": repr(unexpected_error)}
|
||||
@@ -18,9 +18,13 @@ import os
|
||||
from unittest import mock
|
||||
|
||||
import google.adk
|
||||
from google.adk.tools.bigquery.client import DP_USER_AGENT
|
||||
from google.adk.tools.bigquery.client import get_bigquery_client
|
||||
from google.adk.tools.bigquery.client import get_dataplex_catalog_client
|
||||
from google.api_core.gapic_v1 import client_info as gapic_client_info
|
||||
import google.auth
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
from google.cloud import dataplex_v1
|
||||
from google.cloud.bigquery import client as bigquery_client
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
@@ -201,3 +205,74 @@ def test_bigquery_client_location_custom():
|
||||
# Verify that the client has the desired project set
|
||||
assert client.project == "test-gcp-project"
|
||||
assert client.location == "us-central1"
|
||||
|
||||
|
||||
# Tests for Dataplex Catalog Client
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Mock the CatalogServiceClient class directly
|
||||
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_default(mock_catalog_service_client):
  """Without an extra user agent, only DP_USER_AGENT is sent."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)

  returned_client = get_dataplex_catalog_client(credentials=fake_credentials)

  mock_catalog_service_client.assert_called_once()
  ctor_kwargs = mock_catalog_service_client.call_args.kwargs

  assert ctor_kwargs["credentials"] == fake_credentials
  passed_info = ctor_kwargs["client_info"]
  assert isinstance(passed_info, gapic_client_info.ClientInfo)
  assert passed_info.user_agent == DP_USER_AGENT

  # The helper must hand back exactly what the (patched) constructor produced.
  assert returned_client == mock_catalog_service_client.return_value
|
||||
|
||||
|
||||
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client):
  """A single custom user-agent string is appended after DP_USER_AGENT."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)
  extra_agent = "catalog_ua/1.0"

  get_dataplex_catalog_client(
      credentials=fake_credentials, user_agent=extra_agent
  )

  mock_catalog_service_client.assert_called_once()
  passed_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert passed_info.user_agent == f"{DP_USER_AGENT} {extra_agent}"
|
||||
|
||||
|
||||
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client):
  """A list of custom user agents is appended, in order, after DP_USER_AGENT."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)
  extra_agents = ["catalog_ua", "catalog_ua_2.0"]
  joined_extras = " ".join(extra_agents)

  get_dataplex_catalog_client(
      credentials=fake_credentials, user_agent=extra_agents
  )

  mock_catalog_service_client.assert_called_once()
  passed_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert passed_info.user_agent == f"{DP_USER_AGENT} {joined_extras}"
|
||||
|
||||
|
||||
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list_with_none(
    mock_catalog_service_client,
):
  """``None`` entries in the user-agent list are skipped, not stringified."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)
  extra_agents = ["catalog_ua", None, "catalog_ua_2.0"]

  get_dataplex_catalog_client(
      credentials=fake_credentials, user_agent=extra_agents
  )

  mock_catalog_service_client.assert_called_once()
  passed_info = mock_catalog_service_client.call_args.kwargs["client_info"]
  assert passed_info.user_agent == f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0"
|
||||
|
||||
@@ -44,9 +44,11 @@ class TestBigQueryCredentials:
|
||||
|
||||
# Verify that the credentials are properly stored and attributes are extracted
|
||||
assert config.credentials == auth_creds
|
||||
assert config.client_id is None
|
||||
assert config.client_secret is None
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
|
||||
assert config.scopes == [
|
||||
"https://www.googleapis.com/auth/bigquery",
|
||||
"https://www.googleapis.com/auth/dataplex",
|
||||
]
|
||||
|
||||
def test_valid_credentials_object_oauth2_credentials(self):
|
||||
"""Test that providing valid Credentials object works correctly with
|
||||
@@ -86,7 +88,10 @@ class TestBigQueryCredentials:
|
||||
assert config.credentials is None
|
||||
assert config.client_id == "test_client_id"
|
||||
assert config.client_secret == "test_client_secret"
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
|
||||
assert config.scopes == [
|
||||
"https://www.googleapis.com/auth/bigquery",
|
||||
"https://www.googleapis.com/auth/dataplex",
|
||||
]
|
||||
|
||||
def test_valid_client_id_secret_pair_w_scope(self):
|
||||
"""Test that providing client ID and secret with explicit scopes works.
|
||||
@@ -128,7 +133,10 @@ class TestBigQueryCredentials:
|
||||
assert config.credentials is None
|
||||
assert config.client_id == "test_client_id"
|
||||
assert config.client_secret == "test_client_secret"
|
||||
assert config.scopes == ["https://www.googleapis.com/auth/bigquery"]
|
||||
assert config.scopes == [
|
||||
"https://www.googleapis.com/auth/bigquery",
|
||||
"https://www.googleapis.com/auth/dataplex",
|
||||
]
|
||||
|
||||
def test_missing_client_secret_raises_error(self):
|
||||
"""Test that missing client secret raises appropriate validation error.
|
||||
|
||||
@@ -0,0 +1,448 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
# Mock google.genai, pydantic and fastapi if not available, before importing
# google.adk modules. Each fallback registers MagicMock stand-ins in
# sys.modules so that the later `from google.adk...` imports can resolve
# these names in environments where the real packages are not installed.
try:
  import google.genai
except ImportError:
  m = mock.MagicMock()
  # An (empty) __path__ makes the stand-in look like a package, which the
  # import system requires before it will resolve submodules of it.
  m.__path__ = []
  sys.modules["google.genai"] = m
  sys.modules["google.genai.types"] = mock.MagicMock()
  sys.modules["google.genai.errors"] = mock.MagicMock()

try:
  import pydantic
except ImportError:
  m_pydantic = mock.MagicMock()

  # A real (empty) class rather than a MagicMock attribute.
  # NOTE(review): presumably so that code can subclass pydantic.BaseModel --
  # confirm against the modules that import it.
  class MockBaseModel:
    pass

  m_pydantic.BaseModel = MockBaseModel
  sys.modules["pydantic"] = m_pydantic

try:
  import fastapi
  import fastapi.openapi.models
except ImportError:
  m_fastapi = mock.MagicMock()
  m_fastapi.openapi.models = mock.MagicMock()
  # Register every dotted path separately so each import form resolves.
  sys.modules["fastapi"] = m_fastapi
  sys.modules["fastapi.openapi"] = mock.MagicMock()
  sys.modules["fastapi.openapi.models"] = mock.MagicMock()
|
||||
|
||||
|
||||
from google.adk.tools.bigquery import search_tool
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.api_core import exceptions as api_exceptions
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import dataplex_v1
|
||||
|
||||
|
||||
def _mock_creds():
  """Return an autospec'd ``Credentials`` instance for use as a test double."""
  fake_credentials = mock.create_autospec(Credentials, instance=True)
  return fake_credentials
|
||||
|
||||
|
||||
def _mock_settings(app_name: str | None = "test-app"):
  """Return a ``BigQueryToolConfig`` whose application name is ``app_name``."""
  tool_config = BigQueryToolConfig(application_name=app_name)
  return tool_config
|
||||
|
||||
|
||||
def _mock_search_entries_response(results: list[dict[str, Any]]):
  """Build a fake ``SearchEntriesResponse`` from plain dictionaries.

  Args:
    results: One dict per search hit. Recognised keys are ``name``,
      ``entry_type``, ``update_time``, ``display_name``, ``linked_resource``,
      ``description`` and ``location``; absent keys become ``None``, except
      ``update_time`` which defaults to "2026-01-14T05:00:00Z".

  Returns:
    A ``MagicMock`` spec'd on ``SearchEntriesResponse`` whose ``results``
    attribute lists one mocked ``SearchEntriesResult`` per input dict.
  """

  def _fake_result(fields):
    # Nested attributes are attached by hand; per the original author's note
    # they are not visible in dir() of the proto classes, so autospec does
    # not create them.
    fake_result = mock.create_autospec(
        dataplex_v1.SearchEntriesResult, instance=True
    )
    fake_entry = mock.create_autospec(dataplex_v1.Entry, instance=True)
    fake_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True)

    fake_result.dataplex_entry = fake_entry
    fake_entry.entry_source = fake_source

    fake_entry.name = fields.get("name")
    fake_entry.entry_type = fields.get("entry_type")
    fake_entry.update_time = fields.get("update_time", "2026-01-14T05:00:00Z")

    fake_source.display_name = fields.get("display_name")
    fake_source.resource = fields.get("linked_resource")
    fake_source.description = fields.get("description")
    fake_source.location = fields.get("location")
    return fake_result

  fake_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse)
  fake_response.results = [_fake_result(fields) for fields in results]
  return fake_response
|
||||
|
||||
|
||||
class TestSearchCatalog(parameterized.TestCase):
  """Tests for ``search_tool.search_catalog`` with the Dataplex client mocked.

  ``setUp`` patches both the client factory and ``SearchEntriesRequest``, so
  every test can inspect the keyword arguments that ``search_catalog`` passes
  to them without any Dataplex call being made.
  """

  def setUp(self):
    """Install the client-factory and request-class patches for each test."""
    super().setUp()
    self.mock_dataplex_client = mock.create_autospec(
        dataplex_v1.CatalogServiceClient, instance=True
    )

    # Patch get_dataplex_catalog_client
    self.mock_get_dataplex_client = self.enter_context(
        mock.patch(
            "google.adk.tools.bigquery.client.get_dataplex_catalog_client",
            autospec=True,
        )
    )
    self.mock_get_dataplex_client.return_value = self.mock_dataplex_client
    # search_catalog uses the client as a context manager, so `with ... as`
    # must yield the same mock the tests configure.
    self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client

    # Patch SearchEntriesRequest
    self.mock_search_request = self.enter_context(
        mock.patch(
            "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True
        )
    )

  def test_search_catalog_success(self):
    """Test the successful path of search_catalog."""
    creds = _mock_creds()
    settings = _mock_settings()
    prompt = "customer data"
    project_id = "test-project"
    location = "us"

    mock_api_results = [{
        "name": "entry1",
        "entry_type": "TABLE",
        "display_name": "Cust Table",
        "linked_resource": (
            "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1"
        ),
        "description": "Table 1",
        "location": "us",
    }]
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response(mock_api_results)
    )

    result = search_tool.search_catalog(
        prompt=prompt,
        project_id=project_id,
        credentials=creds,
        settings=settings,
        location=location,
    )

    with self.subTest("Test result content"):
      self.assertEqual(result["status"], "SUCCESS")
      self.assertLen(result["results"], 1)
      self.assertEqual(result["results"][0]["name"], "entry1")
      self.assertEqual(result["results"][0]["display_name"], "Cust Table")

    with self.subTest("Test mock calls"):
      self.mock_get_dataplex_client.assert_called_once_with(
          credentials=creds, user_agent=["test-app", "search_catalog"]
      )

      # With no explicit project filter the scoping project is used as the
      # projectid clause.
      expected_query = (
          '(customer data) AND projectid="test-project" AND system=BIGQUERY'
      )
      self.mock_search_request.assert_called_once_with(
          name=f"projects/{project_id}/locations/us",
          query=expected_query,
          page_size=10,
          semantic_search=True,
      )
      self.mock_dataplex_client.search_entries.assert_called_once_with(
          request=self.mock_search_request.return_value
      )

  def test_search_catalog_no_project_id(self):
    """Test search_catalog with missing project_id."""
    result = search_tool.search_catalog(
        prompt="test",
        project_id="",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )
    self.assertEqual(result["status"], "ERROR")
    self.assertIn("project_id must be provided", result["error_details"])
    # The validation must short-circuit before any client is created.
    self.mock_get_dataplex_client.assert_not_called()

  def test_search_catalog_api_error(self):
    """Test search_catalog handling API exceptions."""
    self.mock_dataplex_client.search_entries.side_effect = (
        api_exceptions.BadRequest("Invalid query")
    )

    result = search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )
    self.assertEqual(result["status"], "ERROR")
    self.assertIn(
        "Dataplex API Error: 400 Invalid query", result["error_details"]
    )

  def test_search_catalog_other_exception(self):
    """Test search_catalog handling unexpected exceptions."""
    # Fail at client creation so the generic `except Exception` path is hit.
    self.mock_get_dataplex_client.side_effect = Exception(
        "Something went wrong"
    )

    result = search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
    )
    self.assertEqual(result["status"], "ERROR")
    self.assertIn("Something went wrong", result["error_details"])

  # Each case: (name, prompt, project_ids, dataset_ids, types,
  # expected_query_part).
  @parameterized.named_parameters(
      ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'),
      (
          "multi_project_filter",
          "p",
          ["p1", "p2"],
          None,
          None,
          '(projectid="p1" OR projectid="p2")',
      ),
      ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'),
      (
          "multi_type_filter",
          "p",
          None,
          None,
          ["TABLE", "DATASET"],
          '(type="TABLE" OR type="DATASET")',
      ),
      (
          "project_and_dataset_filters",
          "inventory",
          ["proj1", "proj2"],
          ["dsetA"],
          None,
          (
              '(projectid="proj1" OR projectid="proj2") AND'
              ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"'
              ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")'
          ),
      ),
  )
  def test_search_catalog_query_construction(
      self, prompt, project_ids, dataset_ids, types, expected_query_part
  ):
    """Test different query constructions based on filters."""
    search_tool.search_catalog(
        prompt=prompt,
        project_id="test-project",
        credentials=_mock_creds(),
        settings=_mock_settings(),
        location="us",
        project_ids_filter=project_ids,
        dataset_ids_filter=dataset_ids,
        types_filter=types,
    )

    self.mock_search_request.assert_called_once()
    _, kwargs = self.mock_search_request.call_args
    query = kwargs["query"]

    # Only the fragments are checked here; exact full-query strings are
    # covered by the dedicated tests in this class.
    if prompt:
      assert f"({prompt})" in query
    assert "system=BIGQUERY" in query
    assert expected_query_part in query

  def test_search_catalog_no_app_name(self):
    """Test search_catalog when settings.application_name is None."""
    creds = _mock_creds()
    settings = _mock_settings(app_name=None)
    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=creds,
        settings=settings,
        location="us",
    )

    # search_catalog forwards the None as-is to the client factory.
    self.mock_get_dataplex_client.assert_called_once_with(
        credentials=creds, user_agent=[None, "search_catalog"]
    )

  def test_search_catalog_multi_project_filter_semantic(self):
    """Test semantic search with a multi-project filter."""
    creds = _mock_creds()
    settings = _mock_settings()
    prompt = "What datasets store user profiles?"
    project_id = "main-project"
    project_filters = ["user-data-proj", "shared-infra-proj"]
    location = "global"

    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    search_tool.search_catalog(
        prompt=prompt,
        project_id=project_id,
        credentials=creds,
        settings=settings,
        location=location,
        project_ids_filter=project_filters,
        types_filter=["DATASET"],
    )

    # The scoping project appears only in the request name; the projectid
    # clause lists the filter projects instead.
    expected_query = (
        f"({prompt}) AND "
        '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND '
        'type="DATASET" AND system=BIGQUERY'
    )
    self.mock_search_request.assert_called_once_with(
        name=f"projects/{project_id}/locations/{location}",
        query=expected_query,
        page_size=10,
        semantic_search=True,
    )
    self.mock_dataplex_client.search_entries.assert_called_once()

  def test_search_catalog_natural_language_semantic(self):
    """Test natural language prompts with semantic search enabled and check output."""
    creds = _mock_creds()
    settings = _mock_settings()
    prompt = "Find tables about football matches"
    project_id = "sports-analytics"
    location = "europe-west1"

    # Mock the results that the API would return for this semantic query
    mock_api_results = [
        {
            "name": (
                "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1"
            ),
            "display_name": "uk_football_premiership",
            "entry_type": (
                "projects/655216118709/locations/global/entryTypes/bigquery-table"
            ),
            "linked_resource": (
                "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership"
            ),
            "description": "Stats for UK Premier League matches.",
            "location": "europe-west1",
        },
        {
            "name": (
                "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2"
            ),
            "display_name": "serie_a_matches",
            "entry_type": (
                "projects/655216118709/locations/global/entryTypes/bigquery-table"
            ),
            "linked_resource": (
                "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a"
            ),
            "description": "Italian Serie A football results.",
            "location": "europe-west1",
        },
    ]
    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response(mock_api_results)
    )

    result = search_tool.search_catalog(
        prompt=prompt,
        project_id=project_id,
        credentials=creds,
        settings=settings,
        location=location,
    )

    with self.subTest("Query Construction"):
      # Assert the request was made as expected
      expected_query = (
          f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY'
      )
      self.mock_search_request.assert_called_once_with(
          name=f"projects/{project_id}/locations/{location}",
          query=expected_query,
          page_size=10,
          semantic_search=True,
      )
      self.mock_dataplex_client.search_entries.assert_called_once()

    with self.subTest("Response Processing"):
      # Assert the output is processed correctly
      self.assertEqual(result["status"], "SUCCESS")
      self.assertLen(result["results"], 2)
      self.assertEqual(
          result["results"][0]["display_name"], "uk_football_premiership"
      )
      self.assertEqual(result["results"][1]["display_name"], "serie_a_matches")
      self.assertIn("UK Premier League", result["results"][0]["description"])

  def test_search_catalog_default_location(self):
    """Test search_catalog fallback to global location when None is provided."""
    creds = _mock_creds()
    settings = _mock_settings()
    # settings.location is None by default

    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=creds,
        settings=settings,
    )

    self.mock_search_request.assert_called_once()
    _, kwargs = self.mock_search_request.call_args
    name_arg = kwargs["name"]
    self.assertIn("locations/global", name_arg)

  def test_search_catalog_settings_location(self):
    """Test search_catalog uses settings.location when provided."""
    creds = _mock_creds()
    settings = BigQueryToolConfig(location="eu")

    self.mock_dataplex_client.search_entries.return_value = (
        _mock_search_entries_response([])
    )

    # No explicit `location` argument, so the value from settings applies.
    search_tool.search_catalog(
        prompt="test",
        project_id="test-project",
        credentials=creds,
        settings=settings,
    )

    self.mock_search_request.assert_called_once()
    _, kwargs = self.mock_search_request.call_args
    name_arg = kwargs["name"]
    self.assertIn("locations/eu", name_arg)
|
||||
@@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default():
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == 10
|
||||
assert len(tools) == 11
|
||||
assert all([isinstance(tool, GoogleTool) for tool in tools])
|
||||
|
||||
expected_tool_names = set([
|
||||
@@ -55,6 +55,7 @@ async def test_bigquery_toolset_tools_default():
|
||||
"forecast",
|
||||
"analyze_contribution",
|
||||
"detect_anomalies",
|
||||
"search_catalog",
|
||||
])
|
||||
actual_tool_names = set([tool.name for tool in tools])
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
Reference in New Issue
Block a user