diff --git a/contributing/samples/bigquery/README.md b/contributing/samples/bigquery/README.md index 3ed97432..fc3f8610 100644 --- a/contributing/samples/bigquery/README.md +++ b/contributing/samples/bigquery/README.md @@ -55,6 +55,9 @@ distributed via the `google.adk.tools.bigquery` module. These tools include: `ARIMA_PLUS` model and then querying it with `ML.DETECT_ANOMALIES` to detect time series data anomalies. +11. `search_catalog` + Searches for data entries across projects using the Dataplex Catalog. This allows discovery of datasets, tables, and other assets. + ## How to use Set up environment variables in your `.env` file for using @@ -159,3 +162,4 @@ the necessary access tokens to call BigQuery APIs on their behalf. * which tables exist in the ml_datasets dataset? * show more details about the penguins table * compute penguins population per island. +* are there any tables related to animals in project YOUR_PROJECT_ID? diff --git a/pyproject.toml b/pyproject.toml index 9bec96cb..a1f136d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "google-cloud-bigquery-storage>=2.0.0", "google-cloud-bigquery>=2.2.0", "google-cloud-bigtable>=2.32.0", # For Bigtable database + "google-cloud-dataplex>=1.7.0,<3.0.0", # For Dataplex Catalog Search tool "google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool "google-cloud-pubsub>=2.0.0, <3.0.0", # For Pub/Sub Tool "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index fa23c74c..958ce9d7 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -19,6 +19,10 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" +BIGQUERY_SCOPES = [ +
"https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", +] BIGQUERY_DEFAULT_SCOPE = ["https://www.googleapis.com/auth/bigquery"] @@ -34,8 +38,8 @@ class BigQueryCredentialsConfig(BaseGoogleCredentialsConfig): super().__post_init__() if not self.scopes: - self.scopes = BIGQUERY_DEFAULT_SCOPE - + self.scopes = BIGQUERY_SCOPES + # Set the token cache key self._token_cache_key = BIGQUERY_TOKEN_CACHE_KEY return self diff --git a/src/google/adk/tools/bigquery/bigquery_toolset.py b/src/google/adk/tools/bigquery/bigquery_toolset.py index 1a748b71..dba5f8ee 100644 --- a/src/google/adk/tools/bigquery/bigquery_toolset.py +++ b/src/google/adk/tools/bigquery/bigquery_toolset.py @@ -24,6 +24,7 @@ from typing_extensions import override from . import data_insights_tool from . import metadata_tool from . import query_tool +from . import search_tool from ...features import experimental from ...features import FeatureName from ...tools.base_tool import BaseTool @@ -87,6 +88,7 @@ class BigQueryToolset(BaseToolset): query_tool.analyze_contribution, query_tool.detect_anomalies, data_insights_tool.ask_data_insights, + search_tool.search_catalog, ] ] diff --git a/src/google/adk/tools/bigquery/client.py b/src/google/adk/tools/bigquery/client.py index d57c0c80..2cb4e67c 100644 --- a/src/google/adk/tools/bigquery/client.py +++ b/src/google/adk/tools/bigquery/client.py @@ -14,19 +14,22 @@ from __future__ import annotations +from typing import List from typing import Optional +from typing import Union import google.api_core.client_info +from google.api_core.gapic_v1 import client_info as gapic_client_info from google.auth.credentials import Credentials from google.cloud import bigquery +from google.cloud import dataplex_v1 from ... 
def get_dataplex_catalog_client(
    *,
    credentials: Credentials,
    user_agent: Optional[Union[str, List[str]]] = None,
) -> dataplex_v1.CatalogServiceClient:
    """Build a Dataplex ``CatalogServiceClient`` tagged with the ADK user agent.

    Args:
        credentials: The credentials the client uses for its requests.
        user_agent: Optional extra user agent string, or list of strings, that
            is appended after the default Dataplex tool user agent. Empty or
            ``None`` entries are ignored.

    Returns:
        A ``CatalogServiceClient`` whose client info carries the space-joined
        user agent string.
    """
    # Normalize the optional argument to a list so one code path handles both
    # the single-string and the list form.
    if isinstance(user_agent, str):
        extra_agents = [user_agent]
    else:
        extra_agents = user_agent or []

    combined_agent = " ".join(
        [DP_USER_AGENT, *(agent for agent in extra_agents if agent)]
    )

    return dataplex_v1.CatalogServiceClient(
        credentials=credentials,
        client_info=gapic_client_info.ClientInfo(user_agent=combined_agent),
    )
def _construct_search_query_helper(
    predicate: str, operator: str, items: list[str]
) -> str:
    """Build one filter clause matching ``predicate`` against any of ``items``.

    Args:
        predicate: The search predicate name (e.g. ``projectid``, ``type``).
        operator: The operator placed between predicate and value (e.g. ``=``).
        items: Values to match. Empty or ``None`` entries are skipped so they
            do not produce clauses such as ``type=""``.

    Returns:
        An empty string when no usable value is given, the bare clause for one
        value, or a parenthesized ``OR`` group for several values.
    """
    # NOTE(review): values are interpolated between double quotes without
    # escaping; a value containing `"` would alter the query. The Dataplex
    # escaping rules are not visible here -- confirm before accepting
    # untrusted values.
    clauses = [f'{predicate}{operator}"{item}"' for item in (items or []) if item]
    if not clauses:
        return ""
    if len(clauses) == 1:
        return clauses[0]
    return "(" + " OR ".join(clauses) + ")"


def _build_search_query(
    prompt: str,
    project_id: str,
    project_ids_filter: list[str] | None,
    dataset_ids_filter: list[str] | None,
    types_filter: list[str] | None,
) -> str:
    """Assemble the full Dataplex search query string for ``search_catalog``."""
    parts = []

    # A blank prompt would otherwise contribute an empty "( )" group.
    if prompt and prompt.strip():
        parts.append(f"({prompt})")

    # Fall back to the scoping project when no usable project filter is given.
    projects = [pid for pid in (project_ids_filter or []) if pid] or [project_id]
    parts.append(_construct_search_query_helper("projectid", "=", projects))

    # One linked_resource clause per (project, dataset) pair.
    datasets = [did for did in (dataset_ids_filter or []) if did]
    if datasets:
        dataset_clauses = [
            f'linked_resource:"//bigquery.googleapis.com/projects/{pid}/datasets/{did}/*"'
            for pid in projects
            for did in datasets
        ]
        parts.append(f"({' OR '.join(dataset_clauses)})")

    if types_filter:
        parts.append(_construct_search_query_helper("type", "=", types_filter))

    # Always scope to BigQuery system
    parts.append("system=BIGQUERY")

    return " AND ".join(part for part in parts if part)


def search_catalog(
    prompt: str,
    project_id: str,
    *,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    location: str | None = None,
    page_size: int = 10,
    project_ids_filter: list[str] | None = None,
    dataset_ids_filter: list[str] | None = None,
    types_filter: list[str] | None = None,
) -> dict[str, Any]:
    """Searches for BigQuery assets within Dataplex.

    Args:
        prompt: The base search query (natural language or keywords).
        project_id: The Google Cloud project ID to scope the search.
        credentials: Credentials for the request.
        settings: BigQuery tool settings.
        location: The Dataplex location to use. Falls back to
            ``settings.location`` and then to "global".
        page_size: Maximum number of results.
        project_ids_filter: Specific project IDs to include in the search
            results. If None or empty, defaults to the scoping project_id.
        dataset_ids_filter: BigQuery dataset IDs to filter by.
        types_filter: Entry types to filter by (e.g., "TABLE", "DATASET").

    Returns:
        Search results or error. On success the dict has "status": "SUCCESS"
        and a "results" list; on failure it has "status": "ERROR" and
        "error_details". The "results" list contains items with:
        - name: The Dataplex Entry name (e.g.,
          "projects/p/locations/l/entryGroups/g/entries/e").
        - linked_resource: The underlying BigQuery resource name (e.g.,
          "//bigquery.googleapis.com/projects/p/datasets/d/tables/t").
        - display_name, entry_type, description, location, update_time.

    Examples:
        Search for tables related to customer data:

        >>> search_catalog(
        ...     prompt="Search for tables related to customer data",
        ...     project_id="my-project",
        ...     credentials=creds,
        ...     settings=settings
        ... )
        {
          "status": "SUCCESS",
          "results": [
            {
              "name":
                "projects/my-project/locations/us/entryGroups/@bigquery/entries/entry-id",
              "display_name": "customer_table",
              "entry_type":
                "projects/p/locations/l/entryTypes/bigquery-table",
              "linked_resource":
                "//bigquery.googleapis.com/projects/my-project/datasets/d/tables/customer_table",
              "description": "Table containing customer details.",
              "location": "us",
              "update_time": "2024-01-01 12:00:00+00:00"
            }
          ]
        }
    """
    # Pure argument validation: kept outside the try so the broad handlers
    # below only guard the client and API interaction.
    if not project_id:
        return {
            "status": "ERROR",
            "error_details": "project_id must be provided.",
        }

    try:
        with client.get_dataplex_catalog_client(
            credentials=credentials,
            user_agent=[settings.application_name, "search_catalog"],
        ) as dataplex_client:
            full_query = _build_search_query(
                prompt,
                project_id,
                project_ids_filter,
                dataset_ids_filter,
                types_filter,
            )

            search_location = location or settings.location or "global"
            search_scope = f"projects/{project_id}/locations/{search_location}"

            request = dataplex_v1.SearchEntriesRequest(
                name=search_scope,
                query=full_query,
                page_size=page_size,
                semantic_search=True,
            )

            response = dataplex_client.search_entries(request=request)

            results = []
            for result in response.results:
                entry = result.dataplex_entry
                source = entry.entry_source
                update_time = entry.update_time
                results.append({
                    "name": entry.name,
                    "display_name": source.display_name or "",
                    "entry_type": entry.entry_type,
                    # An unset timestamp would otherwise be rendered as the
                    # literal string "None"; use "" like the other fields.
                    "update_time": str(update_time) if update_time else "",
                    "linked_resource": source.resource or "",
                    "description": source.description or "",
                    "location": source.location or "",
                })
            return {"status": "SUCCESS", "results": results}

    except api_exceptions.GoogleAPICallError as e:
        logging.exception("search_catalog tool: API call failed")
        return {"status": "ERROR", "error_details": f"Dataplex API Error: {e}"}
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        logging.exception("search_catalog tool: Unexpected error")
        return {"status": "ERROR", "error_details": repr(e)}
# Tests for Dataplex Catalog Client
# ------------------------------------------------------------------------------


# CatalogServiceClient is patched on the dataplex_v1 module so no real client
# is ever constructed.
@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_default(mock_catalog_service_client):
    """Without extra user agents only the default Dataplex user agent is sent."""
    creds = mock.create_autospec(Credentials, instance=True)

    returned_client = get_dataplex_catalog_client(credentials=creds)

    mock_catalog_service_client.assert_called_once()
    call_kwargs = mock_catalog_service_client.call_args.kwargs
    sent_client_info = call_kwargs["client_info"]

    assert call_kwargs["credentials"] == creds
    assert isinstance(sent_client_info, gapic_client_info.ClientInfo)
    assert sent_client_info.user_agent == DP_USER_AGENT

    # The factory hands back exactly what the (mocked) constructor produced.
    assert returned_client == mock_catalog_service_client.return_value


@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_str(mock_catalog_service_client):
    """A single custom user agent string is appended after the default one."""
    creds = mock.create_autospec(Credentials, instance=True)
    extra_agent = "catalog_ua/1.0"

    get_dataplex_catalog_client(credentials=creds, user_agent=extra_agent)

    mock_catalog_service_client.assert_called_once()
    sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
    assert sent_client_info.user_agent == f"{DP_USER_AGENT} {extra_agent}"


@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list(mock_catalog_service_client):
    """Every entry of a custom user agent list is appended, space separated."""
    creds = mock.create_autospec(Credentials, instance=True)
    extra_agents = ["catalog_ua", "catalog_ua_2.0"]

    get_dataplex_catalog_client(credentials=creds, user_agent=extra_agents)

    mock_catalog_service_client.assert_called_once()
    sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
    assert sent_client_info.user_agent == f"{DP_USER_AGENT} {' '.join(extra_agents)}"


@mock.patch.object(dataplex_v1, "CatalogServiceClient", autospec=True)
def test_dataplex_client_custom_user_agent_list_with_none(
    mock_catalog_service_client,
):
    """``None`` entries in a custom user agent list are dropped."""
    creds = mock.create_autospec(Credentials, instance=True)
    extra_agents = ["catalog_ua", None, "catalog_ua_2.0"]

    get_dataplex_catalog_client(credentials=creds, user_agent=extra_agents)

    mock_catalog_service_client.assert_called_once()
    sent_client_info = mock_catalog_service_client.call_args.kwargs["client_info"]
    assert (
        sent_client_info.user_agent
        == f"{DP_USER_AGENT} catalog_ua catalog_ua_2.0"
    )
attributes are extracted assert config.credentials == auth_creds - assert config.client_id is None assert config.client_secret is None - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_credentials_object_oauth2_credentials(self): """Test that providing valid Credentials object works correctly with @@ -86,7 +88,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_valid_client_id_secret_pair_w_scope(self): """Test that providing client ID and secret with explicit scopes works. @@ -128,7 +133,10 @@ class TestBigQueryCredentials: assert config.credentials is None assert config.client_id == "test_client_id" assert config.client_secret == "test_client_secret" - assert config.scopes == ["https://www.googleapis.com/auth/bigquery"] + assert config.scopes == [ + "https://www.googleapis.com/auth/bigquery", + "https://www.googleapis.com/auth/dataplex", + ] def test_missing_client_secret_raises_error(self): """Test that missing client secret raises appropriate validation error. diff --git a/tests/unittests/tools/bigquery/test_bigquery_search_tool.py b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py new file mode 100644 index 00000000..0ccdc9e1 --- /dev/null +++ b/tests/unittests/tools/bigquery/test_bigquery_search_tool.py @@ -0,0 +1,448 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _mock_creds():
    """Return an autospecced ``Credentials`` instance."""
    return mock.create_autospec(Credentials, instance=True)


def _mock_settings(app_name: str | None = "test-app"):
    """Return a ``BigQueryToolConfig`` with the given application name."""
    return BigQueryToolConfig(application_name=app_name)


def _mock_search_entries_response(results: list[dict[str, Any]]):
    """Build a mock ``SearchEntriesResponse`` from plain result dicts.

    Args:
        results: One dict per search hit. The keys read are ``name``,
            ``entry_type``, ``update_time`` (defaults to
            "2026-01-14T05:00:00Z"), ``display_name``, ``linked_resource``,
            ``description`` and ``location``.

    Returns:
        A mock response whose ``results`` list holds one mocked
        ``SearchEntriesResult`` per input dict.
    """
    mock_response = mock.MagicMock(spec=dataplex_v1.SearchEntriesResponse)
    mock_results = []
    for r in results:
        mock_result = mock.create_autospec(
            dataplex_v1.SearchEntriesResult, instance=True
        )
        # Manually attach dataplex_entry since it's not visible in dir() of the
        # proto class
        mock_entry = mock.create_autospec(dataplex_v1.Entry, instance=True)
        mock_result.dataplex_entry = mock_entry

        mock_entry.name = r.get("name")
        mock_entry.entry_type = r.get("entry_type")
        mock_entry.update_time = r.get("update_time", "2026-01-14T05:00:00Z")

        # Manually attach entry_source since it's not visible in dir() of the
        # proto class
        mock_source = mock.create_autospec(dataplex_v1.EntrySource, instance=True)
        mock_entry.entry_source = mock_source

        mock_source.display_name = r.get("display_name")
        mock_source.resource = r.get("linked_resource")
        mock_source.description = r.get("description")
        mock_source.location = r.get("location")
        mock_results.append(mock_result)
    mock_response.results = mock_results
    return mock_response


class TestSearchCatalog(parameterized.TestCase):
    """Unit tests for ``search_tool.search_catalog`` with Dataplex mocked out."""

    def setUp(self):
        """Patch the Dataplex client factory and the request class."""
        super().setUp()
        self.mock_dataplex_client = mock.create_autospec(
            dataplex_v1.CatalogServiceClient, instance=True
        )

        # Patch get_dataplex_catalog_client
        self.mock_get_dataplex_client = self.enter_context(
            mock.patch(
                "google.adk.tools.bigquery.client.get_dataplex_catalog_client",
                autospec=True,
            )
        )
        self.mock_get_dataplex_client.return_value = self.mock_dataplex_client
        # search_catalog uses the client as a context manager, so entering it
        # must yield the same mock.
        self.mock_dataplex_client.__enter__.return_value = self.mock_dataplex_client

        # Patch SearchEntriesRequest
        self.mock_search_request = self.enter_context(
            mock.patch(
                "google.cloud.dataplex_v1.SearchEntriesRequest", autospec=True
            )
        )

    def test_search_catalog_success(self):
        """Test the successful path of search_catalog."""
        creds = _mock_creds()
        settings = _mock_settings()
        prompt = "customer data"
        project_id = "test-project"
        location = "us"

        mock_api_results = [{
            "name": "entry1",
            "entry_type": "TABLE",
            "display_name": "Cust Table",
            "linked_resource": (
                "//bigquery.googleapis.com/projects/p/datasets/d/tables/t1"
            ),
            "description": "Table 1",
            "location": "us",
        }]
        self.mock_dataplex_client.search_entries.return_value = (
            _mock_search_entries_response(mock_api_results)
        )

        result = search_tool.search_catalog(
            prompt=prompt,
            project_id=project_id,
            credentials=creds,
            settings=settings,
            location=location,
        )

        with self.subTest("Test result content"):
            self.assertEqual(result["status"], "SUCCESS")
            self.assertLen(result["results"], 1)
            self.assertEqual(result["results"][0]["name"], "entry1")
            self.assertEqual(result["results"][0]["display_name"], "Cust Table")

        with self.subTest("Test mock calls"):
            self.mock_get_dataplex_client.assert_called_once_with(
                credentials=creds, user_agent=["test-app", "search_catalog"]
            )

            expected_query = (
                '(customer data) AND projectid="test-project" AND system=BIGQUERY'
            )
            self.mock_search_request.assert_called_once_with(
                name=f"projects/{project_id}/locations/us",
                query=expected_query,
                page_size=10,
                semantic_search=True,
            )
            self.mock_dataplex_client.search_entries.assert_called_once_with(
                request=self.mock_search_request.return_value
            )

    def test_search_catalog_no_project_id(self):
        """Test search_catalog with missing project_id."""
        result = search_tool.search_catalog(
            prompt="test",
            project_id="",
            credentials=_mock_creds(),
            settings=_mock_settings(),
            location="us",
        )
        self.assertEqual(result["status"], "ERROR")
        self.assertIn("project_id must be provided", result["error_details"])
        self.mock_get_dataplex_client.assert_not_called()

    def test_search_catalog_api_error(self):
        """Test search_catalog handling API exceptions."""
        self.mock_dataplex_client.search_entries.side_effect = (
            api_exceptions.BadRequest("Invalid query")
        )

        result = search_tool.search_catalog(
            prompt="test",
            project_id="test-project",
            credentials=_mock_creds(),
            settings=_mock_settings(),
            location="us",
        )
        self.assertEqual(result["status"], "ERROR")
        self.assertIn(
            "Dataplex API Error: 400 Invalid query", result["error_details"]
        )

    def test_search_catalog_other_exception(self):
        """Test search_catalog handling unexpected exceptions."""
        self.mock_get_dataplex_client.side_effect = Exception(
            "Something went wrong"
        )

        result = search_tool.search_catalog(
            prompt="test",
            project_id="test-project",
            credentials=_mock_creds(),
            settings=_mock_settings(),
            location="us",
        )
        self.assertEqual(result["status"], "ERROR")
        self.assertIn("Something went wrong", result["error_details"])

    @parameterized.named_parameters(
        ("project_filter", "p", ["proj1"], None, None, 'projectid="proj1"'),
        (
            "multi_project_filter",
            "p",
            ["p1", "p2"],
            None,
            None,
            '(projectid="p1" OR projectid="p2")',
        ),
        ("type_filter", "p", None, None, ["TABLE"], 'type="TABLE"'),
        (
            "multi_type_filter",
            "p",
            None,
            None,
            ["TABLE", "DATASET"],
            '(type="TABLE" OR type="DATASET")',
        ),
        (
            "project_and_dataset_filters",
            "inventory",
            ["proj1", "proj2"],
            ["dsetA"],
            None,
            (
                '(projectid="proj1" OR projectid="proj2") AND'
                ' (linked_resource:"//bigquery.googleapis.com/projects/proj1/datasets/dsetA/*"'
                ' OR linked_resource:"//bigquery.googleapis.com/projects/proj2/datasets/dsetA/*")'
            ),
        ),
    )
    def test_search_catalog_query_construction(
        self, prompt, project_ids, dataset_ids, types, expected_query_part
    ):
        """Test different query constructions based on filters."""
        search_tool.search_catalog(
            prompt=prompt,
            project_id="test-project",
            credentials=_mock_creds(),
            settings=_mock_settings(),
            location="us",
            project_ids_filter=project_ids,
            dataset_ids_filter=dataset_ids,
            types_filter=types,
        )

        self.mock_search_request.assert_called_once()
        _, kwargs = self.mock_search_request.call_args
        query = kwargs["query"]

        # TestCase assertions (rather than bare asserts) match the rest of the
        # class and report both operands on failure.
        if prompt:
            self.assertIn(f"({prompt})", query)
        self.assertIn("system=BIGQUERY", query)
        self.assertIn(expected_query_part, query)

    def test_search_catalog_no_app_name(self):
        """Test search_catalog when settings.application_name is None."""
        creds = _mock_creds()
        settings = _mock_settings(app_name=None)
        search_tool.search_catalog(
            prompt="test",
            project_id="test-project",
            credentials=creds,
            settings=settings,
            location="us",
        )

        # The None application name is passed through to the client factory
        # unchanged.
        self.mock_get_dataplex_client.assert_called_once_with(
            credentials=creds, user_agent=[None, "search_catalog"]
        )

    def test_search_catalog_multi_project_filter_semantic(self):
        """Test semantic search with a multi-project filter."""
        creds = _mock_creds()
        settings = _mock_settings()
        prompt = "What datasets store user profiles?"
        project_id = "main-project"
        project_filters = ["user-data-proj", "shared-infra-proj"]
        location = "global"

        self.mock_dataplex_client.search_entries.return_value = (
            _mock_search_entries_response([])
        )

        search_tool.search_catalog(
            prompt=prompt,
            project_id=project_id,
            credentials=creds,
            settings=settings,
            location=location,
            project_ids_filter=project_filters,
            types_filter=["DATASET"],
        )

        expected_query = (
            f"({prompt}) AND "
            '(projectid="user-data-proj" OR projectid="shared-infra-proj") AND '
            'type="DATASET" AND system=BIGQUERY'
        )
        self.mock_search_request.assert_called_once_with(
            name=f"projects/{project_id}/locations/{location}",
            query=expected_query,
            page_size=10,
            semantic_search=True,
        )
        self.mock_dataplex_client.search_entries.assert_called_once()

    def test_search_catalog_natural_language_semantic(self):
        """Test natural language prompts with semantic search enabled and check output."""
        creds = _mock_creds()
        settings = _mock_settings()
        prompt = "Find tables about football matches"
        project_id = "sports-analytics"
        location = "europe-west1"

        # Mock the results that the API would return for this semantic query
        mock_api_results = [
            {
                "name": (
                    "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb1"
                ),
                "display_name": "uk_football_premiership",
                "entry_type": (
                    "projects/655216118709/locations/global/entryTypes/bigquery-table"
                ),
                "linked_resource": (
                    "//bigquery.googleapis.com/projects/sports-analytics/datasets/uk/tables/premiership"
                ),
                "description": "Stats for UK Premier League matches.",
                "location": "europe-west1",
            },
            {
                "name": (
                    "projects/sports-analytics/locations/europe-west1/entryGroups/@bigquery/entries/fb2"
                ),
                "display_name": "serie_a_matches",
                "entry_type": (
                    "projects/655216118709/locations/global/entryTypes/bigquery-table"
                ),
                "linked_resource": (
                    "//bigquery.googleapis.com/projects/sports-analytics/datasets/italy/tables/serie_a"
                ),
                "description": "Italian Serie A football results.",
                "location": "europe-west1",
            },
        ]
        self.mock_dataplex_client.search_entries.return_value = (
            _mock_search_entries_response(mock_api_results)
        )

        result = search_tool.search_catalog(
            prompt=prompt,
            project_id=project_id,
            credentials=creds,
            settings=settings,
            location=location,
        )

        with self.subTest("Query Construction"):
            # Assert the request was made as expected
            expected_query = (
                f'({prompt}) AND projectid="{project_id}" AND system=BIGQUERY'
            )
            self.mock_search_request.assert_called_once_with(
                name=f"projects/{project_id}/locations/{location}",
                query=expected_query,
                page_size=10,
                semantic_search=True,
            )
            self.mock_dataplex_client.search_entries.assert_called_once()

        with self.subTest("Response Processing"):
            # Assert the output is processed correctly
            self.assertEqual(result["status"], "SUCCESS")
            self.assertLen(result["results"], 2)
            self.assertEqual(
                result["results"][0]["display_name"], "uk_football_premiership"
            )
            self.assertEqual(
                result["results"][1]["display_name"], "serie_a_matches"
            )
            self.assertIn(
                "UK Premier League", result["results"][0]["description"]
            )

    def test_search_catalog_default_location(self):
        """Test search_catalog fallback to global location when None is provided."""
        creds = _mock_creds()
        settings = _mock_settings()
        # settings.location is None by default

        self.mock_dataplex_client.search_entries.return_value = (
            _mock_search_entries_response([])
        )

        search_tool.search_catalog(
            prompt="test",
            project_id="test-project",
            credentials=creds,
            settings=settings,
        )

        self.mock_search_request.assert_called_once()
        _, kwargs = self.mock_search_request.call_args
        name_arg = kwargs["name"]
        self.assertIn("locations/global", name_arg)

    def test_search_catalog_settings_location(self):
        """Test search_catalog uses settings.location when provided."""
        creds = _mock_creds()
        settings = BigQueryToolConfig(location="eu")

        self.mock_dataplex_client.search_entries.return_value = (
            _mock_search_entries_response([])
        )

        search_tool.search_catalog(
            prompt="test",
            project_id="test-project",
            credentials=creds,
            settings=settings,
        )

        self.mock_search_request.assert_called_once()
        _, kwargs = self.mock_search_request.call_args
        name_arg = kwargs["name"]
        self.assertIn("locations/eu", name_arg)
assert actual_tool_names == expected_tool_names