diff --git a/contributing/samples/bigquery_data_agent/README.md b/contributing/samples/bigquery_data_agent/README.md deleted file mode 100644 index b04ecae5..00000000 --- a/contributing/samples/bigquery_data_agent/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# BigQuery Data Agent Sample - -This sample agent demonstrates ADK's first-party tools for interacting with -Data Agents based on BigQuery's Conversational Analytics API, distributed via -the `google.adk.tools.bigquery` module. These tools allow you to list, -inspect, and -chat with BigQuery Data Agents using natural language. - -These tools leverage stateful conversations, meaning you can ask follow-up -questions in the same session, and the agent will maintain context. - -## Prerequisites - -1. An active Google Cloud project with BigQuery and Gemini APIs enabled. -2. Google Cloud authentication configured for Application Default Credentials: - ```bash - gcloud auth application-default login - ``` -3. At least one Data Agent created in BigQuery Studio (Data Canvas). These - agents are created and configured in the Google Cloud console and point to - your BigQuery tables or other data sources. - -## Tools Used - -* `list_accessible_data_agents`: Lists Data Agents you have permission to - access in the configured GCP project. -* `get_data_agent_info`: Retrieves details about a specific Data Agent given - its full resource name. -* `ask_data_agent`: Chats with a specific Data Agent using natural language. - This tool maintains conversation state: if you ask multiple - questions to the same agent in one session, it will use the same - conversation, allowing for follow-ups. If you switch agents, a new - conversation will be started for the new agent. - -## How to Run - -1. Navigate to the root of the ADK repository. -2. Run the agent using the ADK CLI: - ```bash - adk run --agent-path contributing/samples/bigquery_data_agent - ``` -3. The CLI will prompt you for input. You can ask questions like the examples - below. - -## Sample prompts - -* "List accessible data agents." -* "Using agent - `projects/my-project/locations/global/dataAgents/sales-agent-123`, who were - my top 3 customers last quarter?" -* "How does that compare to the quarter before?" diff --git a/contributing/samples/bigquery_data_agent/__init__.py b/contributing/samples/bigquery_data_agent/__init__.py deleted file mode 100644 index c48963cd..00000000 --- a/contributing/samples/bigquery_data_agent/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import agent diff --git a/contributing/samples/bigquery_data_agent/agent.py b/contributing/samples/bigquery_data_agent/agent.py deleted file mode 100644 index e1512d79..00000000 --- a/contributing/samples/bigquery_data_agent/agent.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -from google.adk.agents import Agent -from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig -from google.adk.tools.bigquery.bigquery_data_agent_toolset import BigQueryDataAgentToolset -from google.adk.tools.bigquery.config import BigQueryToolConfig -from google.adk.tools.bigquery.config import WriteMode -import google.auth - -# Define the desired credential type. -# By default use Application Default Credentials (ADC) from the local -# environment, which can be set up by following -# https://cloud.google.com/docs/authentication/provide-credentials-adc. -CREDENTIALS_TYPE = None - -# Define an appropriate application name -BIGQUERY_AGENT_NAME = "adk_sample_bigquery_agent" - - -# Define BigQuery tool config with write mode set to allowed. Note that this is -# only to demonstrate the full capability of the BigQuery tools. In production -# you may want to change to BLOCKED (default write mode, effectively makes the -# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a -# BigQuery session) write mode. -tool_config = BigQueryToolConfig( - write_mode=WriteMode.ALLOWED, application_name=BIGQUERY_AGENT_NAME -) - -if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: - # Initiaze the tools to do interactive OAuth - # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET - # must be set - credentials_config = BigQueryCredentialsConfig( - client_id=os.getenv("OAUTH_CLIENT_ID"), - client_secret=os.getenv("OAUTH_CLIENT_SECRET"), - ) -elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT: - # Initialize the tools to use the credentials in the service account key. - # If this flow is enabled, make sure to replace the file path with your own - # service account key file - # https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys - creds, _ = google.auth.load_credentials_from_file("service_account_key.json") - credentials_config = BigQueryCredentialsConfig(credentials=creds) -else: - # Initialize the tools to use the application default credentials. - # https://cloud.google.com/docs/authentication/provide-credentials-adc - application_default_credentials, _ = google.auth.default() - credentials_config = BigQueryCredentialsConfig( - credentials=application_default_credentials - ) - -bq_da_toolset = BigQueryDataAgentToolset( - credentials_config=credentials_config, - bigquery_tool_config=tool_config, - tool_filter=[ - "list_accessible_data_agents", - "get_data_agent_info", - "ask_data_agent", - ], -) - -root_agent = Agent( - name="bigquery_data_agent", - model="gemini-2.0-flash", - description="Agent to answer user questions using BigQuery Data Agents.", - instruction=( - "## Persona\nYou are a helpful assistant that uses BigQuery Data Agents" - " to answer user questions about their data.\n\n## Tools\n- You can" - " list available data agents using `list_accessible_data_agents`.\n-" - " You can get information about a specific data agent using" - " `get_data_agent_info`.\n- You can chat with a specific data" - " agent using `ask_data_agent`.\n" - ), - tools=[bq_da_toolset], -) diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 7bbc3dcf..036b56ef 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -24,7 +24,6 @@ from ..utils.env_utils import is_env_enabled class FeatureName(str, Enum): """Feature names.""" - BIG_QUERY_DATA_AGENT_TOOLSET = "BIG_QUERY_DATA_AGENT_TOOLSET" BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET" BIG_QUERY_TOOL_CONFIG = "BIG_QUERY_TOOL_CONFIG" BIGTABLE_TOOL_SETTINGS = "BIGTABLE_TOOL_SETTINGS" @@ -68,9 +67,6 @@ class FeatureConfig: # Central registry: FeatureName -> FeatureConfig _FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = { - FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET: FeatureConfig( - FeatureStage.EXPERIMENTAL, default_on=True - ), FeatureName.BIG_QUERY_TOOLSET: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/tools/bigquery/__init__.py b/src/google/adk/tools/bigquery/__init__.py index 2efa8ea4..9e6b1166 100644 --- a/src/google/adk/tools/bigquery/__init__.py +++ b/src/google/adk/tools/bigquery/__init__.py @@ -28,11 +28,9 @@ definition. The rationales to have customized tool are: """ from .bigquery_credentials import BigQueryCredentialsConfig -from .bigquery_data_agent_toolset import BigQueryDataAgentToolset from .bigquery_toolset import BigQueryToolset __all__ = [ "BigQueryToolset", "BigQueryCredentialsConfig", - "BigQueryDataAgentToolset", ] diff --git a/src/google/adk/tools/bigquery/bigquery_data_agent_toolset.py b/src/google/adk/tools/bigquery/bigquery_data_agent_toolset.py deleted file mode 100644 index da6b00be..00000000 --- a/src/google/adk/tools/bigquery/bigquery_data_agent_toolset.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import List -from typing import Optional -from typing import Union - -from google.adk.agents.readonly_context import ReadonlyContext -from typing_extensions import override - -from . import data_insights_tool -from ...features import experimental -from ...features import FeatureName -from ...tools.base_tool import BaseTool -from ...tools.base_toolset import BaseToolset -from ...tools.base_toolset import ToolPredicate -from ...tools.google_tool import GoogleTool -from .bigquery_credentials import BigQueryCredentialsConfig -from .config import BigQueryToolConfig - - -@experimental(FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET) -class BigQueryDataAgentToolset(BaseToolset): - """BigQuery Data Agent Toolset contains tools for interacting with BigQuery data agents.""" - - def __init__( - self, - *, - tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, - credentials_config: Optional[BigQueryCredentialsConfig] = None, - bigquery_tool_config: Optional[BigQueryToolConfig] = None, - ): - super().__init__(tool_filter=tool_filter) - self._credentials_config = credentials_config - self._tool_settings = ( - bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig() - ) - - def _is_tool_selected( - self, tool: BaseTool, readonly_context: ReadonlyContext - ) -> bool: - if self.tool_filter is None: - return True - - if isinstance(self.tool_filter, ToolPredicate): - return self.tool_filter(tool, readonly_context) - - if isinstance(self.tool_filter, list): - return tool.name in self.tool_filter - - return False - - @override - async def get_tools( - self, readonly_context: Optional[ReadonlyContext] = None - ) -> List[BaseTool]: - """Get tools from the toolset.""" - all_tools = [ - GoogleTool( - func=func, - credentials_config=self._credentials_config, - tool_settings=self._tool_settings, - ) - for func in [ - data_insights_tool.list_accessible_data_agents, - data_insights_tool.get_data_agent_info, - data_insights_tool.ask_data_agent, - ] - ] - - return [ - tool - for tool in all_tools - if self._is_tool_selected(tool, readonly_context) - ] - - @override - async def close(self): - pass diff --git a/src/google/adk/tools/bigquery/data_insights_tool.py b/src/google/adk/tools/bigquery/data_insights_tool.py index cd55b9a5..0d7280c2 100644 --- a/src/google/adk/tools/bigquery/data_insights_tool.py +++ b/src/google/adk/tools/bigquery/data_insights_tool.py @@ -20,11 +20,9 @@ from typing import List from google.auth.credentials import Credentials from google.cloud import bigquery -from google.cloud import geminidataanalytics import requests from . import client -from ..tool_context import ToolContext from .config import BigQueryToolConfig @@ -114,7 +112,6 @@ def ask_data_insights( ] } """ - # TODO(huanc): replace this with official client library. try: location = "global" if not credentials.token: @@ -166,329 +163,9 @@ def ask_data_insights( return {"status": "SUCCESS", "response": resp} -def list_accessible_data_agents( - project_id: str, - credentials: Credentials, -) -> Dict[str, Any]: - """Lists accessible data agents in a project. - - Args: - project_id: The project to list agents in. - credentials: The credentials to use for the request. - - Returns: - A dictionary containing the status and a list of data agents with their - detailed information, including name, display_name, description (if - available), create_time, update_time, and data_analytics_agent context, - or error details if the request fails. - - Examples: - >>> list_accessible_data_agents( - ... project_id="my-gcp-project", - ... credentials=credentials, - ... ) - { - "status": "SUCCESS", - "response": [ - { - "name": "projects/my-project/locations/global/dataAgents/agent1", - "display_name": "My Test Agent", - "create_time": {"seconds": 1759358662, "nanos": 473927629}, - "update_time": {"seconds": 1759358663, "nanos": 94541325}, - "data_analytics_agent": { - "published_context": { - "datasource_references": [{ - "bq": { - "table_references": [{ - "project_id": "my-project", - "dataset_id": "dataset1", - "table_id": "table1" - }] - } - }] - } - } - }, - { - "name": "projects/my-project/locations/global/dataAgents/agent2", - "display_name": "", - "description": "Description for Agent 2.", - "create_time": {"seconds": 1750710228, "nanos": 650597312}, - "update_time": {"seconds": 1750710229, "nanos": 437095391}, - "data_analytics_agent": { - "published_context": { - "datasource_references": [{ - "bq": { - "table_references": [{ - "project_id": "another-project", - "dataset_id": "dataset2", - "table_id": "table2" - }] - } - }], - "system_instruction": "You are a helpful assistant.", - "options": {"analysis": {"python": {"enabled": True}}} - } - } - } - ] - } - """ - try: - client = geminidataanalytics.DataAgentServiceClient(credentials=credentials) - request = geminidataanalytics.ListAccessibleDataAgentsRequest( - parent=f"projects/{project_id}/locations/global", - ) - page_result = client.list_accessible_data_agents(request=request) - return { - "status": "SUCCESS", - "response": [str(agent) for agent in page_result], - } - except Exception as ex: # pylint: disable=broad-except - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def get_data_agent_info( - data_agent_name: str, - credentials: Credentials, -) -> Dict[str, Any]: - """Gets a data agent by name. - - Args: - data_agent_name: The name of the agent to get, in format - projects/{project}/locations/{location}/dataAgents/{agent}. - credentials: The credentials to use for the request. - - Returns: - A dictionary containing the status and details of a data agent, - including name, display_name, description (if available), - create_time, update_time, and data_analytics_agent context, - or error details if the request fails. - - Examples: - >>> get_data_agent_info( - ... data_agent_name="projects/my-project/locations/global/dataAgents/agent-1", - ... credentials=credentials, - ... ) - { - "status": "SUCCESS", - "response": { - "name": "projects/my-project/locations/global/dataAgents/agent-1", - "display_name": "My Agent 1", - "description": "Description for Agent 1.", - "create_time": {"seconds": 1750710228, "nanos": 650597312}, - "update_time": {"seconds": 1750710229, "nanos": 437095391}, - "data_analytics_agent": { - "published_context": { - "datasource_references": [{ - "bq": { - "table_references": [{ - "project_id": "my-gcp-project", - "dataset_id": "dataset1", - "table_id": "table1" - }] - } - }], - "system_instruction": "You are a helpful assistant.", - "options": {"analysis": {"python": {"enabled": True}}} - } - } - } - } - """ - try: - client = geminidataanalytics.DataAgentServiceClient(credentials=credentials) - request = geminidataanalytics.GetDataAgentRequest( - name=data_agent_name, - ) - response = client.get_data_agent(request=request) - return { - "status": "SUCCESS", - "response": str(response), - } - except Exception as ex: # pylint: disable=broad-except - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def ask_data_agent( - data_agent_name: str, - query: str, - *, - credentials: Credentials, - tool_context: ToolContext, -) -> Dict[str, Any]: - """Asks a question to a data agent. - - Args: - data_agent_name: The resource name of an existing data agent to ask, - in format projects/{project}/locations/{location}/dataAgents/{agent}. - query: The question to ask the agent. - credentials: The credentials to use for the request. - tool_context: The context for the tool. - - Returns: - A dictionary with two keys: - - 'status': A string indicating the final status (e.g., "SUCCESS"). - - 'response': A list of dictionaries, where each dictionary - represents a step in the agent's execution process (e.g., SQL - generation, data retrieval, final answer). Note that the 'Answer' - step contains a text response which may summarize findings or refer - to previous steps of agent execution, such as 'Data Retrieved', in - which cases, the 'Answer' step does not include the result data. - - Examples: - A query to a data agent, showing the full return structure. - The original question: "Which customer from New York spent the most last - month?" - - >>> ask_data_agent( - ... data_agent_name="projects/my-project/locations/global/dataAgents/sales-agent", - ... query="Which customer from New York spent the most last month?", - ... credentials=credentials, - ... tool_context=tool_context, - ... ) - { - "status": "SUCCESS", - "response": [ - { - "SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... " - }, - { - "Data Retrieved": { - "headers": ["customer_name", "total_spent"], - "rows": [["Jane Doe", 1234.56]], - "summary": "Showing all 1 rows." - } - }, - { - "Answer": "The customer who spent the most was Jane Doe." - } - ] - } - """ - try: - agent_info = get_data_agent_info(data_agent_name, credentials) - if agent_info.get("status") == "ERROR": - return agent_info - client = geminidataanalytics.DataChatServiceClient(credentials=credentials) - parent = data_agent_name.rsplit("/", 2)[0] - conversation_name = None - - if ( - tool_context.state.get("bigquery_data_agent_conv_agent") - == data_agent_name - ): - conversation_name = tool_context.state.get( - "bigquery_data_agent_conv_name" - ) - else: - conversation = geminidataanalytics.Conversation() - conversation.agents = [data_agent_name] - request = geminidataanalytics.CreateConversationRequest( - parent=parent, - conversation=conversation, - ) - response = client.create_conversation(request=request) - conversation_name = response.name - tool_context.state["bigquery_data_agent_conv_agent"] = data_agent_name - tool_context.state["bigquery_data_agent_conv_name"] = conversation_name - - new_user_message = geminidataanalytics.Message() - new_user_message.user_message.text = query - messages = [new_user_message] - - if conversation_name: - conversation_reference = geminidataanalytics.ConversationReference() - conversation_reference.conversation = conversation_name - conversation_reference.data_agent_context.data_agent = data_agent_name - request = geminidataanalytics.ChatRequest( - parent=parent, - messages=messages, - conversation_reference=conversation_reference, - ) - else: - data_agent_context = geminidataanalytics.DataAgentContext() - data_agent_context.data_agent = data_agent_name - request = geminidataanalytics.ChatRequest( - parent=parent, - messages=messages, - data_agent_context=data_agent_context, - ) - stream = client.chat(request=request) - responses = list(stream) - print({ - "status": "SUCCESS", - "response": _process_data_agent_stream(responses), - }) - return { - "status": "SUCCESS", - "response": _process_data_agent_stream(responses), - } - except Exception as ex: # pylint: disable=broad-except - return { - "status": "ERROR", - "error_details": str(ex), - } - - -def _process_result_message(result, max_rows: int) -> dict[str, Any]: - """Processes result message from data agent chat.""" - headers = [f.name for f in result.schema.fields] - all_rows_structs = result.data - total_rows = len(all_rows_structs) - - summary_string = f"Showing all {total_rows} rows." - if total_rows > max_rows: - summary_string = f"Showing the first {max_rows} of {total_rows} total rows." - rows = [] - i = 0 - for row in all_rows_structs: - if i >= max_rows: - break - rows.append([row.get(h) for h in headers]) - i += 1 - return { - "headers": headers, - "rows": rows, - "summary": summary_string, - } - - -def _process_data_agent_stream( - stream: list[geminidataanalytics.Message], max_rows: int = 1000 -) -> list[dict[str, Any]]: - """Processes stream from data agent chat.""" - processed_responses = [] - for i, msg in enumerate(stream): - if msg.system_message: - message = msg.system_message - if message.text.parts: - processed_responses.append({"Answer": "".join(message.text.parts)}) - elif message.data.generated_sql: - processed_responses.append( - {"SQL Generated": message.data.generated_sql} - ) - elif message.data.result and message.data.result.data: - processed_responses.append({ - "Data Retrieved": _process_result_message( - message.data.result, max_rows - ) - }) - elif message.error: - processed_responses.append({"Error": message.error.text}) - return processed_responses - - def _get_stream( url: str, ca_payload: Dict[str, Any], - *, headers: Dict[str, str], max_query_result_rows: int, ) -> List[Dict[str, Any]]: diff --git a/tests/unittests/tools/bigquery/conftest.py b/tests/unittests/tools/bigquery/conftest.py deleted file mode 100644 index 500dc8e8..00000000 --- a/tests/unittests/tools/bigquery/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -import types -from unittest import mock - -MOCK_GEMINI_DATA_ANALYTICS = mock.MagicMock() -sys.modules["google.cloud.geminidataanalytics"] = MOCK_GEMINI_DATA_ANALYTICS - -# The mock.patch calls require 'geminidataanalytics' to be an attribute of -# 'google.cloud' module for patching by string to work. -try: - import google.cloud -except ImportError: - sys.modules["google.cloud"] = types.ModuleType("google.cloud") -finally: - sys.modules["google.cloud"].geminidataanalytics = MOCK_GEMINI_DATA_ANALYTICS diff --git a/tests/unittests/tools/bigquery/test_bigquery_data_agent_toolset.py b/tests/unittests/tools/bigquery/test_bigquery_data_agent_toolset.py deleted file mode 100644 index 2735423d..00000000 --- a/tests/unittests/tools/bigquery/test_bigquery_data_agent_toolset.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from unittest import mock - -from google.adk.tools.bigquery import BigQueryCredentialsConfig -from google.adk.tools.bigquery import BigQueryDataAgentToolset -from google.adk.tools.bigquery.config import BigQueryToolConfig -from google.adk.tools.google_tool import GoogleTool -import pytest - - -@pytest.mark.asyncio -async def test_bigquery_data_agent_toolset_tools_default(): - """Test default BigQueryDataAgentToolset. - - This test verifies the behavior of the BigQueryDataAgentToolset when no filter is - specified. - """ - credentials_config = BigQueryCredentialsConfig( - client_id="abc", client_secret="def" - ) - toolset = BigQueryDataAgentToolset( - credentials_config=credentials_config, bigquery_tool_config=None - ) - # Verify that the tool config is initialized to default values. - assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access - assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access - - tools = await toolset.get_tools() - assert tools is not None - - assert len(tools) == 3 - assert all(isinstance(tool, GoogleTool) for tool in tools) - - expected_tool_names = set([ - "list_accessible_data_agents", - "get_data_agent_info", - "ask_data_agent", - ]) - actual_tool_names = {tool.name for tool in tools} - assert actual_tool_names == expected_tool_names - - -@pytest.mark.parametrize( - "selected_tools", - [ - pytest.param([], id="None"), - pytest.param( - ["list_accessible_data_agents", "get_data_agent_info"], - id="list_and_get", - ), - pytest.param(["ask_data_agent"], id="ask"), - ], -) -@pytest.mark.asyncio -async def test_bigquery_data_agent_toolset_tools_selective(selected_tools): - """Test BigQueryDataAgentToolset with filter. - - This test verifies the behavior of the BigQueryDataAgentToolset when filter is - specified. A use case for this would be when the agent builder wants to - use only a subset of the tools provided by the toolset. - """ - credentials_config = BigQueryCredentialsConfig( - client_id="abc", client_secret="def" - ) - toolset = BigQueryDataAgentToolset( - credentials_config=credentials_config, tool_filter=selected_tools - ) - tools = await toolset.get_tools() - assert tools is not None - - assert len(tools) == len(selected_tools) - assert all(isinstance(tool, GoogleTool) for tool in tools) - - expected_tool_names = set(selected_tools) - actual_tool_names = {tool.name for tool in tools} - assert actual_tool_names == expected_tool_names - - -@pytest.mark.parametrize( - ("selected_tools", "returned_tools"), - [ - pytest.param(["unknown"], [], id="all-unknown"), - pytest.param( - ["unknown", "ask_data_agent"], - ["ask_data_agent"], - id="mixed-known-unknown", - ), - ], -) -@pytest.mark.asyncio -async def test_bigquery_data_agent_toolset_unknown_tool( - selected_tools, returned_tools -): - """Test BigQueryDataAgentToolset with filter. - - This test verifies the behavior of the BigQueryDataAgentToolset when filter is - specified with an unknown tool. - """ - credentials_config = BigQueryCredentialsConfig( - client_id="abc", client_secret="def" - ) - - toolset = BigQueryDataAgentToolset( - credentials_config=credentials_config, tool_filter=selected_tools - ) - - tools = await toolset.get_tools() - assert tools is not None - - assert len(tools) == len(returned_tools) - assert all(isinstance(tool, GoogleTool) for tool in tools) - - expected_tool_names = set(returned_tools) - actual_tool_names = {tool.name for tool in tools} - assert actual_tool_names == expected_tool_names diff --git a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py index 10cc1292..f7d0fa06 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py @@ -16,7 +16,6 @@ import pathlib from unittest import mock from google.adk.tools.bigquery import data_insights_tool -from google.adk.tools.tool_context import ToolContext import pytest import yaml @@ -270,134 +269,3 @@ def test_handle_error(response_dict, expected_output): """Tests the error response handler.""" result = data_insights_tool._handle_error(response_dict) # pylint: disable=protected-access assert result == expected_output - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataAgentServiceClient" -) -def test_list_accessible_data_agents_success(mock_data_agent_client): - """Tests list_accessible_data_agents success path.""" - mock_creds = mock.Mock() - mock_agent1 = mock.MagicMock() - mock_agent1.__str__.return_value = "agent1" - mock_agent2 = mock.MagicMock() - mock_agent2.__str__.return_value = "agent2" - mock_data_agent_client.return_value.list_accessible_data_agents.return_value = [ - mock_agent1, - mock_agent2, - ] - result = data_insights_tool.list_accessible_data_agents( - "test-project", mock_creds - ) - assert result["status"] == "SUCCESS" - assert result["response"] == ["agent1", "agent2"] - mock_data_agent_client.assert_called_once_with(credentials=mock_creds) - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataAgentServiceClient" -) -def test_list_accessible_data_agents_exception(mock_data_agent_client): - """Tests list_accessible_data_agents exception path.""" - mock_creds = mock.Mock() - mock_data_agent_client.return_value.list_accessible_data_agents.side_effect = Exception( - "List failed!" - ) - result = data_insights_tool.list_accessible_data_agents( - "test-project", mock_creds - ) - assert result["status"] == "ERROR" - assert "List failed!" in result["error_details"] - mock_data_agent_client.assert_called_once_with(credentials=mock_creds) - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataAgentServiceClient" -) -def test_get_data_agent_info_success(mock_data_agent_client): - """Tests get_data_agent_info success path.""" - mock_creds = mock.Mock() - mock_response = mock.MagicMock() - mock_response.__str__.return_value = "agent_info" - mock_data_agent_client.return_value.get_data_agent.return_value = ( - mock_response - ) - result = data_insights_tool.get_data_agent_info("agent_name", mock_creds) - assert result["status"] == "SUCCESS" - assert result["response"] == "agent_info" - mock_data_agent_client.assert_called_once_with(credentials=mock_creds) - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataAgentServiceClient" -) -def test_get_data_agent_info_exception(mock_data_agent_client): - """Tests get_data_agent_info exception path.""" - mock_creds = mock.Mock() - mock_data_agent_client.return_value.get_data_agent.side_effect = Exception( - "Get failed!" - ) - result = data_insights_tool.get_data_agent_info("agent_name", mock_creds) - assert result["status"] == "ERROR" - assert "Get failed!" in result["error_details"] - mock_data_agent_client.assert_called_once_with(credentials=mock_creds) - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataChatServiceClient" -) -def test_ask_data_agent_success(mock_data_chat_client): - """Tests ask_data_agent success path.""" - mock_creds = mock.Mock() - mock_invocation_context = mock.Mock() - mock_invocation_context.session.state = {} - mock_context = ToolContext(mock_invocation_context) - mock_response1 = mock.MagicMock() - mock_response1.system_message.text.parts = ["response1"] - mock_response1.system_message.data.generated_sql = None - mock_response1.system_message.data.result = None - mock_response1.system_message.error = None - mock_response2 = mock.MagicMock() - mock_response2.system_message.text.parts = ["response2"] - mock_response2.system_message.data.generated_sql = None - mock_response2.system_message.data.result = None - mock_response2.system_message.error = None - mock_data_chat_client.return_value.chat.return_value = [ - mock_response1, - mock_response2, - ] - result = data_insights_tool.ask_data_agent( - "projects/p/locations/l/dataAgents/a", - "query", - credentials=mock_creds, - tool_context=mock_context, - ) - assert result["status"] == "SUCCESS" - assert result["response"] == [ - {"Answer": "response1"}, - {"Answer": "response2"}, - ] - mock_data_chat_client.assert_called_once_with(credentials=mock_creds) - - -@mock.patch.object( - data_insights_tool.geminidataanalytics, "DataChatServiceClient" -) -def test_ask_data_agent_exception(mock_data_chat_client): - """Tests ask_data_agent exception path.""" - mock_creds = mock.Mock() - mock_invocation_context = mock.Mock() - mock_invocation_context.session.state = {} - mock_context = ToolContext(mock_invocation_context) - mock_data_chat_client.return_value.chat.side_effect = Exception( - "Chat failed!" - ) - result = data_insights_tool.ask_data_agent( - "projects/p/locations/l/dataAgents/a", - "query", - credentials=mock_creds, - tool_context=mock_context, - ) - assert result["status"] == "ERROR" - assert "Chat failed!" in result["error_details"] - mock_data_chat_client.assert_called_once_with(credentials=mock_creds)