feat: add new conversational analytics api tool set

PiperOrigin-RevId: 853489874
This commit is contained in:
Google Team Member
2026-01-07 18:25:31 -08:00
committed by Copybara-Service
parent aaf76a6a51
commit c34feb4c0e
10 changed files with 0 additions and 868 deletions
@@ -1,51 +0,0 @@
# BigQuery Data Agent Sample
This sample agent demonstrates ADK's first-party tools for interacting with
Data Agents based on BigQuery's Conversational Analytics API, distributed via
the `google.adk.tools.bigquery` module. These tools allow you to list,
inspect, and
chat with BigQuery Data Agents using natural language.
These tools leverage stateful conversations, meaning you can ask follow-up
questions in the same session, and the agent will maintain context.
## Prerequisites
1. An active Google Cloud project with BigQuery and Gemini APIs enabled.
2. Google Cloud authentication configured for Application Default Credentials:
```bash
gcloud auth application-default login
```
3. At least one Data Agent created in BigQuery Studio (Data Canvas). These
agents are created and configured in the Google Cloud console and point to
your BigQuery tables or other data sources.
## Tools Used
* `list_accessible_data_agents`: Lists Data Agents you have permission to
access in the configured GCP project.
* `get_data_agent_info`: Retrieves details about a specific Data Agent given
its full resource name.
* `ask_data_agent`: Chats with a specific Data Agent using natural language.
This tool maintains conversation state: if you ask multiple
questions to the same agent in one session, it will use the same
conversation, allowing for follow-ups. If you switch agents, a new
conversation will be started for the new agent.
## How to Run
1. Navigate to the root of the ADK repository.
2. Run the agent using the ADK CLI:
```bash
adk run --agent-path contributing/samples/bigquery_data_agent
```
3. The CLI will prompt you for input. You can ask questions like the examples
below.
## Sample prompts
* "List accessible data agents."
* "Using agent
`projects/my-project/locations/global/dataAgents/sales-agent-123`, who were
my top 3 customers last quarter?"
* "How does that compare to the quarter before?"
@@ -1,15 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -1,90 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from google.adk.agents import Agent
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
from google.adk.tools.bigquery.bigquery_data_agent_toolset import BigQueryDataAgentToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
import google.auth
# Define the desired credential type.
# By default use Application Default Credentials (ADC) from the local
# environment, which can be set up by following
# https://cloud.google.com/docs/authentication/provide-credentials-adc.
CREDENTIALS_TYPE = None
# Define an appropriate application name
BIGQUERY_AGENT_NAME = "adk_sample_bigquery_agent"
# Define BigQuery tool config with write mode set to allowed. Note that this is
# only to demonstrate the full capability of the BigQuery tools. In production
# you may want to change to BLOCKED (default write mode, effectively makes the
# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a
# BigQuery session) write mode.
tool_config = BigQueryToolConfig(
write_mode=WriteMode.ALLOWED, application_name=BIGQUERY_AGENT_NAME
)
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
# Initiaze the tools to do interactive OAuth
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
# must be set
credentials_config = BigQueryCredentialsConfig(
client_id=os.getenv("OAUTH_CLIENT_ID"),
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
)
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
# Initialize the tools to use the credentials in the service account key.
# If this flow is enabled, make sure to replace the file path with your own
# service account key file
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
credentials_config = BigQueryCredentialsConfig(credentials=creds)
else:
# Initialize the tools to use the application default credentials.
# https://cloud.google.com/docs/authentication/provide-credentials-adc
application_default_credentials, _ = google.auth.default()
credentials_config = BigQueryCredentialsConfig(
credentials=application_default_credentials
)
bq_da_toolset = BigQueryDataAgentToolset(
credentials_config=credentials_config,
bigquery_tool_config=tool_config,
tool_filter=[
"list_accessible_data_agents",
"get_data_agent_info",
"ask_data_agent",
],
)
root_agent = Agent(
name="bigquery_data_agent",
model="gemini-2.0-flash",
description="Agent to answer user questions using BigQuery Data Agents.",
instruction=(
"## Persona\nYou are a helpful assistant that uses BigQuery Data Agents"
" to answer user questions about their data.\n\n## Tools\n- You can"
" list available data agents using `list_accessible_data_agents`.\n-"
" You can get information about a specific data agent using"
" `get_data_agent_info`.\n- You can chat with a specific data"
" agent using `ask_data_agent`.\n"
),
tools=[bq_da_toolset],
)
@@ -24,7 +24,6 @@ from ..utils.env_utils import is_env_enabled
class FeatureName(str, Enum):
"""Feature names."""
BIG_QUERY_DATA_AGENT_TOOLSET = "BIG_QUERY_DATA_AGENT_TOOLSET"
BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET"
BIG_QUERY_TOOL_CONFIG = "BIG_QUERY_TOOL_CONFIG"
BIGTABLE_TOOL_SETTINGS = "BIGTABLE_TOOL_SETTINGS"
@@ -68,9 +67,6 @@ class FeatureConfig:
# Central registry: FeatureName -> FeatureConfig
_FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
FeatureName.BIG_QUERY_TOOLSET: FeatureConfig(
FeatureStage.EXPERIMENTAL, default_on=True
),
@@ -28,11 +28,9 @@ definition. The rationales to have customized tool are:
"""
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_data_agent_toolset import BigQueryDataAgentToolset
from .bigquery_toolset import BigQueryToolset
__all__ = [
"BigQueryToolset",
"BigQueryCredentialsConfig",
"BigQueryDataAgentToolset",
]
@@ -1,92 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import List
from typing import Optional
from typing import Union
from google.adk.agents.readonly_context import ReadonlyContext
from typing_extensions import override
from . import data_insights_tool
from ...features import experimental
from ...features import FeatureName
from ...tools.base_tool import BaseTool
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ...tools.google_tool import GoogleTool
from .bigquery_credentials import BigQueryCredentialsConfig
from .config import BigQueryToolConfig
@experimental(FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET)
class BigQueryDataAgentToolset(BaseToolset):
"""BigQuery Data Agent Toolset contains tools for interacting with BigQuery data agents."""
def __init__(
self,
*,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
):
super().__init__(tool_filter=tool_filter)
self._credentials_config = credentials_config
self._tool_settings = (
bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
)
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
) -> bool:
if self.tool_filter is None:
return True
if isinstance(self.tool_filter, ToolPredicate):
return self.tool_filter(tool, readonly_context)
if isinstance(self.tool_filter, list):
return tool.name in self.tool_filter
return False
@override
async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> List[BaseTool]:
"""Get tools from the toolset."""
all_tools = [
GoogleTool(
func=func,
credentials_config=self._credentials_config,
tool_settings=self._tool_settings,
)
for func in [
data_insights_tool.list_accessible_data_agents,
data_insights_tool.get_data_agent_info,
data_insights_tool.ask_data_agent,
]
]
return [
tool
for tool in all_tools
if self._is_tool_selected(tool, readonly_context)
]
@override
async def close(self):
pass
@@ -20,11 +20,9 @@ from typing import List
from google.auth.credentials import Credentials
from google.cloud import bigquery
from google.cloud import geminidataanalytics
import requests
from . import client
from ..tool_context import ToolContext
from .config import BigQueryToolConfig
@@ -114,7 +112,6 @@ def ask_data_insights(
]
}
"""
# TODO(huanc): replace this with official client library.
try:
location = "global"
if not credentials.token:
@@ -166,329 +163,9 @@ def ask_data_insights(
return {"status": "SUCCESS", "response": resp}
def list_accessible_data_agents(
project_id: str,
credentials: Credentials,
) -> Dict[str, Any]:
"""Lists accessible data agents in a project.
Args:
project_id: The project to list agents in.
credentials: The credentials to use for the request.
Returns:
A dictionary containing the status and a list of data agents with their
detailed information, including name, display_name, description (if
available), create_time, update_time, and data_analytics_agent context,
or error details if the request fails.
Examples:
>>> list_accessible_data_agents(
... project_id="my-gcp-project",
... credentials=credentials,
... )
{
"status": "SUCCESS",
"response": [
{
"name": "projects/my-project/locations/global/dataAgents/agent1",
"display_name": "My Test Agent",
"create_time": {"seconds": 1759358662, "nanos": 473927629},
"update_time": {"seconds": 1759358663, "nanos": 94541325},
"data_analytics_agent": {
"published_context": {
"datasource_references": [{
"bq": {
"table_references": [{
"project_id": "my-project",
"dataset_id": "dataset1",
"table_id": "table1"
}]
}
}]
}
}
},
{
"name": "projects/my-project/locations/global/dataAgents/agent2",
"display_name": "",
"description": "Description for Agent 2.",
"create_time": {"seconds": 1750710228, "nanos": 650597312},
"update_time": {"seconds": 1750710229, "nanos": 437095391},
"data_analytics_agent": {
"published_context": {
"datasource_references": [{
"bq": {
"table_references": [{
"project_id": "another-project",
"dataset_id": "dataset2",
"table_id": "table2"
}]
}
}],
"system_instruction": "You are a helpful assistant.",
"options": {"analysis": {"python": {"enabled": True}}}
}
}
}
]
}
"""
try:
client = geminidataanalytics.DataAgentServiceClient(credentials=credentials)
request = geminidataanalytics.ListAccessibleDataAgentsRequest(
parent=f"projects/{project_id}/locations/global",
)
page_result = client.list_accessible_data_agents(request=request)
return {
"status": "SUCCESS",
"response": [str(agent) for agent in page_result],
}
except Exception as ex: # pylint: disable=broad-except
return {
"status": "ERROR",
"error_details": str(ex),
}
def get_data_agent_info(
data_agent_name: str,
credentials: Credentials,
) -> Dict[str, Any]:
"""Gets a data agent by name.
Args:
data_agent_name: The name of the agent to get, in format
projects/{project}/locations/{location}/dataAgents/{agent}.
credentials: The credentials to use for the request.
Returns:
A dictionary containing the status and details of a data agent,
including name, display_name, description (if available),
create_time, update_time, and data_analytics_agent context,
or error details if the request fails.
Examples:
>>> get_data_agent_info(
... data_agent_name="projects/my-project/locations/global/dataAgents/agent-1",
... credentials=credentials,
... )
{
"status": "SUCCESS",
"response": {
"name": "projects/my-project/locations/global/dataAgents/agent-1",
"display_name": "My Agent 1",
"description": "Description for Agent 1.",
"create_time": {"seconds": 1750710228, "nanos": 650597312},
"update_time": {"seconds": 1750710229, "nanos": 437095391},
"data_analytics_agent": {
"published_context": {
"datasource_references": [{
"bq": {
"table_references": [{
"project_id": "my-gcp-project",
"dataset_id": "dataset1",
"table_id": "table1"
}]
}
}],
"system_instruction": "You are a helpful assistant.",
"options": {"analysis": {"python": {"enabled": True}}}
}
}
}
}
"""
try:
client = geminidataanalytics.DataAgentServiceClient(credentials=credentials)
request = geminidataanalytics.GetDataAgentRequest(
name=data_agent_name,
)
response = client.get_data_agent(request=request)
return {
"status": "SUCCESS",
"response": str(response),
}
except Exception as ex: # pylint: disable=broad-except
return {
"status": "ERROR",
"error_details": str(ex),
}
def ask_data_agent(
data_agent_name: str,
query: str,
*,
credentials: Credentials,
tool_context: ToolContext,
) -> Dict[str, Any]:
"""Asks a question to a data agent.
Args:
data_agent_name: The resource name of an existing data agent to ask,
in format projects/{project}/locations/{location}/dataAgents/{agent}.
query: The question to ask the agent.
credentials: The credentials to use for the request.
tool_context: The context for the tool.
Returns:
A dictionary with two keys:
- 'status': A string indicating the final status (e.g., "SUCCESS").
- 'response': A list of dictionaries, where each dictionary
represents a step in the agent's execution process (e.g., SQL
generation, data retrieval, final answer). Note that the 'Answer'
step contains a text response which may summarize findings or refer
to previous steps of agent execution, such as 'Data Retrieved' in
which cases, the 'Answer' step does not include the result data.
Examples:
A query to a data agent, showing the full return structure.
The original question: "Which customer from New York spent the most last
month?"
>>> ask_data_agent(
... data_agent_name="projects/my-project/locations/global/dataAgents/sales-agent",
... query="Which customer from New York spent the most last month?",
... credentials=credentials,
... tool_context=tool_context,
... )
{
"status": "SUCCESS",
"response": [
{
"SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... "
},
{
"Data Retrieved": {
"headers": ["customer_name", "total_spent"],
"rows": [["Jane Doe", 1234.56]],
"summary": "Showing all 1 rows."
}
},
{
"Answer": "The customer who spent the most was Jane Doe."
}
]
}
"""
try:
agent_info = get_data_agent_info(data_agent_name, credentials)
if agent_info.get("status") == "ERROR":
return agent_info
client = geminidataanalytics.DataChatServiceClient(credentials=credentials)
parent = data_agent_name.rsplit("/", 2)[0]
conversation_name = None
if (
tool_context.state.get("bigquery_data_agent_conv_agent")
== data_agent_name
):
conversation_name = tool_context.state.get(
"bigquery_data_agent_conv_name"
)
else:
conversation = geminidataanalytics.Conversation()
conversation.agents = [data_agent_name]
request = geminidataanalytics.CreateConversationRequest(
parent=parent,
conversation=conversation,
)
response = client.create_conversation(request=request)
conversation_name = response.name
tool_context.state["bigquery_data_agent_conv_agent"] = data_agent_name
tool_context.state["bigquery_data_agent_conv_name"] = conversation_name
new_user_message = geminidataanalytics.Message()
new_user_message.user_message.text = query
messages = [new_user_message]
if conversation_name:
conversation_reference = geminidataanalytics.ConversationReference()
conversation_reference.conversation = conversation_name
conversation_reference.data_agent_context.data_agent = data_agent_name
request = geminidataanalytics.ChatRequest(
parent=parent,
messages=messages,
conversation_reference=conversation_reference,
)
else:
data_agent_context = geminidataanalytics.DataAgentContext()
data_agent_context.data_agent = data_agent_name
request = geminidataanalytics.ChatRequest(
parent=parent,
messages=messages,
data_agent_context=data_agent_context,
)
stream = client.chat(request=request)
responses = list(stream)
print({
"status": "SUCCESS",
"response": _process_data_agent_stream(responses),
})
return {
"status": "SUCCESS",
"response": _process_data_agent_stream(responses),
}
except Exception as ex: # pylint: disable=broad-except
return {
"status": "ERROR",
"error_details": str(ex),
}
def _process_result_message(result, max_rows: int) -> dict[str, Any]:
"""Processes result message from data agent chat."""
headers = [f.name for f in result.schema.fields]
all_rows_structs = result.data
total_rows = len(all_rows_structs)
summary_string = f"Showing all {total_rows} rows."
if total_rows > max_rows:
summary_string = f"Showing the first {max_rows} of {total_rows} total rows."
rows = []
i = 0
for row in all_rows_structs:
if i >= max_rows:
break
rows.append([row.get(h) for h in headers])
i += 1
return {
"headers": headers,
"rows": rows,
"summary": summary_string,
}
def _process_data_agent_stream(
stream: list[geminidataanalytics.Message], max_rows: int = 1000
) -> list[dict[str, Any]]:
"""Processes stream from data agent chat."""
processed_responses = []
for i, msg in enumerate(stream):
if msg.system_message:
message = msg.system_message
if message.text.parts:
processed_responses.append({"Answer": "".join(message.text.parts)})
elif message.data.generated_sql:
processed_responses.append(
{"SQL Generated": message.data.generated_sql}
)
elif message.data.result and message.data.result.data:
processed_responses.append({
"Data Retrieved": _process_result_message(
message.data.result, max_rows
)
})
elif message.error:
processed_responses.append({"Error": message.error.text})
return processed_responses
def _get_stream(
url: str,
ca_payload: Dict[str, Any],
*,
headers: Dict[str, str],
max_query_result_rows: int,
) -> List[Dict[str, Any]]:
@@ -1,29 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import types
from unittest import mock
MOCK_GEMINI_DATA_ANALYTICS = mock.MagicMock()
sys.modules["google.cloud.geminidataanalytics"] = MOCK_GEMINI_DATA_ANALYTICS
# The mock.patch calls require 'geminidataanalytics' to be an attribute of
# 'google.cloud' module for patching by string to work.
try:
import google.cloud
except ImportError:
sys.modules["google.cloud"] = types.ModuleType("google.cloud")
finally:
sys.modules["google.cloud"].geminidataanalytics = MOCK_GEMINI_DATA_ANALYTICS
@@ -1,130 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from unittest import mock
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryDataAgentToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.google_tool import GoogleTool
import pytest
@pytest.mark.asyncio
async def test_bigquery_data_agent_toolset_tools_default():
"""Test default BigQueryDataAgentToolset.
This test verifies the behavior of the BigQueryDataAgentToolset when no filter is
specified.
"""
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryDataAgentToolset(
credentials_config=credentials_config, bigquery_tool_config=None
)
# Verify that the tool config is initialized to default values.
assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access
assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 3
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set([
"list_accessible_data_agents",
"get_data_agent_info",
"ask_data_agent",
])
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
"selected_tools",
[
pytest.param([], id="None"),
pytest.param(
["list_accessible_data_agents", "get_data_agent_info"],
id="list_and_get",
),
pytest.param(["ask_data_agent"], id="ask"),
],
)
@pytest.mark.asyncio
async def test_bigquery_data_agent_toolset_tools_selective(selected_tools):
"""Test BigQueryDataAgentToolset with filter.
This test verifies the behavior of the BigQueryDataAgentToolset when filter is
specified. A use case for this would be when the agent builder wants to
use only a subset of the tools provided by the toolset.
"""
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryDataAgentToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(selected_tools)
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set(selected_tools)
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names
@pytest.mark.parametrize(
("selected_tools", "returned_tools"),
[
pytest.param(["unknown"], [], id="all-unknown"),
pytest.param(
["unknown", "ask_data_agent"],
["ask_data_agent"],
id="mixed-known-unknown",
),
],
)
@pytest.mark.asyncio
async def test_bigquery_data_agent_toolset_unknown_tool(
selected_tools, returned_tools
):
"""Test BigQueryDataAgentToolset with filter.
This test verifies the behavior of the BigQueryDataAgentToolset when filter is
specified with an unknown tool.
"""
credentials_config = BigQueryCredentialsConfig(
client_id="abc", client_secret="def"
)
toolset = BigQueryDataAgentToolset(
credentials_config=credentials_config, tool_filter=selected_tools
)
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == len(returned_tools)
assert all(isinstance(tool, GoogleTool) for tool in tools)
expected_tool_names = set(returned_tools)
actual_tool_names = {tool.name for tool in tools}
assert actual_tool_names == expected_tool_names
@@ -16,7 +16,6 @@ import pathlib
from unittest import mock
from google.adk.tools.bigquery import data_insights_tool
from google.adk.tools.tool_context import ToolContext
import pytest
import yaml
@@ -270,134 +269,3 @@ def test_handle_error(response_dict, expected_output):
"""Tests the error response handler."""
result = data_insights_tool._handle_error(response_dict) # pylint: disable=protected-access
assert result == expected_output
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
)
def test_list_accessible_data_agents_success(mock_data_agent_client):
"""Tests list_accessible_data_agents success path."""
mock_creds = mock.Mock()
mock_agent1 = mock.MagicMock()
mock_agent1.__str__.return_value = "agent1"
mock_agent2 = mock.MagicMock()
mock_agent2.__str__.return_value = "agent2"
mock_data_agent_client.return_value.list_accessible_data_agents.return_value = [
mock_agent1,
mock_agent2,
]
result = data_insights_tool.list_accessible_data_agents(
"test-project", mock_creds
)
assert result["status"] == "SUCCESS"
assert result["response"] == ["agent1", "agent2"]
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
)
def test_list_accessible_data_agents_exception(mock_data_agent_client):
"""Tests list_accessible_data_agents exception path."""
mock_creds = mock.Mock()
mock_data_agent_client.return_value.list_accessible_data_agents.side_effect = Exception(
"List failed!"
)
result = data_insights_tool.list_accessible_data_agents(
"test-project", mock_creds
)
assert result["status"] == "ERROR"
assert "List failed!" in result["error_details"]
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
)
def test_get_data_agent_info_success(mock_data_agent_client):
"""Tests get_data_agent_info success path."""
mock_creds = mock.Mock()
mock_response = mock.MagicMock()
mock_response.__str__.return_value = "agent_info"
mock_data_agent_client.return_value.get_data_agent.return_value = (
mock_response
)
result = data_insights_tool.get_data_agent_info("agent_name", mock_creds)
assert result["status"] == "SUCCESS"
assert result["response"] == "agent_info"
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
)
def test_get_data_agent_info_exception(mock_data_agent_client):
"""Tests get_data_agent_info exception path."""
mock_creds = mock.Mock()
mock_data_agent_client.return_value.get_data_agent.side_effect = Exception(
"Get failed!"
)
result = data_insights_tool.get_data_agent_info("agent_name", mock_creds)
assert result["status"] == "ERROR"
assert "Get failed!" in result["error_details"]
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataChatServiceClient"
)
def test_ask_data_agent_success(mock_data_chat_client):
"""Tests ask_data_agent success path."""
mock_creds = mock.Mock()
mock_invocation_context = mock.Mock()
mock_invocation_context.session.state = {}
mock_context = ToolContext(mock_invocation_context)
mock_response1 = mock.MagicMock()
mock_response1.system_message.text.parts = ["response1"]
mock_response1.system_message.data.generated_sql = None
mock_response1.system_message.data.result = None
mock_response1.system_message.error = None
mock_response2 = mock.MagicMock()
mock_response2.system_message.text.parts = ["response2"]
mock_response2.system_message.data.generated_sql = None
mock_response2.system_message.data.result = None
mock_response2.system_message.error = None
mock_data_chat_client.return_value.chat.return_value = [
mock_response1,
mock_response2,
]
result = data_insights_tool.ask_data_agent(
"projects/p/locations/l/dataAgents/a",
"query",
credentials=mock_creds,
tool_context=mock_context,
)
assert result["status"] == "SUCCESS"
assert result["response"] == [
{"Answer": "response1"},
{"Answer": "response2"},
]
mock_data_chat_client.assert_called_once_with(credentials=mock_creds)
@mock.patch.object(
data_insights_tool.geminidataanalytics, "DataChatServiceClient"
)
def test_ask_data_agent_exception(mock_data_chat_client):
"""Tests ask_data_agent exception path."""
mock_creds = mock.Mock()
mock_invocation_context = mock.Mock()
mock_invocation_context.session.state = {}
mock_context = ToolContext(mock_invocation_context)
mock_data_chat_client.return_value.chat.side_effect = Exception(
"Chat failed!"
)
result = data_insights_tool.ask_data_agent(
"projects/p/locations/l/dataAgents/a",
"query",
credentials=mock_creds,
tool_context=mock_context,
)
assert result["status"] == "ERROR"
assert "Chat failed!" in result["error_details"]
mock_data_chat_client.assert_called_once_with(credentials=mock_creds)