You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: add new conversational analytics api tool set
PiperOrigin-RevId: 853489874
This commit is contained in:
committed by
Copybara-Service
parent
aaf76a6a51
commit
c34feb4c0e
@@ -1,51 +0,0 @@
|
||||
# BigQuery Data Agent Sample
|
||||
|
||||
This sample agent demonstrates ADK's first-party tools for interacting with
|
||||
Data Agents based on BigQuery's Conversational Analytics API, distributed via
|
||||
the `google.adk.tools.bigquery` module. These tools allow you to list,
|
||||
inspect, and
|
||||
chat with BigQuery Data Agents using natural language.
|
||||
|
||||
These tools leverage stateful conversations, meaning you can ask follow-up
|
||||
questions in the same session, and the agent will maintain context.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. An active Google Cloud project with BigQuery and Gemini APIs enabled.
|
||||
2. Google Cloud authentication configured for Application Default Credentials:
|
||||
```bash
|
||||
gcloud auth application-default login
|
||||
```
|
||||
3. At least one Data Agent created in BigQuery Studio (Data Canvas). These
|
||||
agents are created and configured in the Google Cloud console and point to
|
||||
your BigQuery tables or other data sources.
|
||||
|
||||
## Tools Used
|
||||
|
||||
* `list_accessible_data_agents`: Lists Data Agents you have permission to
|
||||
access in the configured GCP project.
|
||||
* `get_data_agent_info`: Retrieves details about a specific Data Agent given
|
||||
its full resource name.
|
||||
* `ask_data_agent`: Chats with a specific Data Agent using natural language.
|
||||
This tool maintains conversation state: if you ask multiple
|
||||
questions to the same agent in one session, it will use the same
|
||||
conversation, allowing for follow-ups. If you switch agents, a new
|
||||
conversation will be started for the new agent.
|
||||
|
||||
## How to Run
|
||||
|
||||
1. Navigate to the root of the ADK repository.
|
||||
2. Run the agent using the ADK CLI:
|
||||
```bash
|
||||
adk run --agent-path contributing/samples/bigquery_data_agent
|
||||
```
|
||||
3. The CLI will prompt you for input. You can ask questions like the examples
|
||||
below.
|
||||
|
||||
## Sample prompts
|
||||
|
||||
* "List accessible data agents."
|
||||
* "Using agent
|
||||
`projects/my-project/locations/global/dataAgents/sales-agent-123`, who were
|
||||
my top 3 customers last quarter?"
|
||||
* "How does that compare to the quarter before?"
|
||||
@@ -1,15 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -1,90 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.tools.bigquery.bigquery_credentials import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery.bigquery_data_agent_toolset import BigQueryDataAgentToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
import google.auth
|
||||
|
||||
# Define the desired credential type.
|
||||
# By default use Application Default Credentials (ADC) from the local
|
||||
# environment, which can be set up by following
|
||||
# https://cloud.google.com/docs/authentication/provide-credentials-adc.
|
||||
CREDENTIALS_TYPE = None
|
||||
|
||||
# Define an appropriate application name
|
||||
BIGQUERY_AGENT_NAME = "adk_sample_bigquery_agent"
|
||||
|
||||
|
||||
# Define BigQuery tool config with write mode set to allowed. Note that this is
|
||||
# only to demonstrate the full capability of the BigQuery tools. In production
|
||||
# you may want to change to BLOCKED (default write mode, effectively makes the
|
||||
# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a
|
||||
# BigQuery session) write mode.
|
||||
tool_config = BigQueryToolConfig(
|
||||
write_mode=WriteMode.ALLOWED, application_name=BIGQUERY_AGENT_NAME
|
||||
)
|
||||
|
||||
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
|
||||
# Initiaze the tools to do interactive OAuth
|
||||
# The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET
|
||||
# must be set
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id=os.getenv("OAUTH_CLIENT_ID"),
|
||||
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
|
||||
)
|
||||
elif CREDENTIALS_TYPE == AuthCredentialTypes.SERVICE_ACCOUNT:
|
||||
# Initialize the tools to use the credentials in the service account key.
|
||||
# If this flow is enabled, make sure to replace the file path with your own
|
||||
# service account key file
|
||||
# https://cloud.google.com/iam/docs/service-account-creds#user-managed-keys
|
||||
creds, _ = google.auth.load_credentials_from_file("service_account_key.json")
|
||||
credentials_config = BigQueryCredentialsConfig(credentials=creds)
|
||||
else:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
# https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
credentials=application_default_credentials
|
||||
)
|
||||
|
||||
bq_da_toolset = BigQueryDataAgentToolset(
|
||||
credentials_config=credentials_config,
|
||||
bigquery_tool_config=tool_config,
|
||||
tool_filter=[
|
||||
"list_accessible_data_agents",
|
||||
"get_data_agent_info",
|
||||
"ask_data_agent",
|
||||
],
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
name="bigquery_data_agent",
|
||||
model="gemini-2.0-flash",
|
||||
description="Agent to answer user questions using BigQuery Data Agents.",
|
||||
instruction=(
|
||||
"## Persona\nYou are a helpful assistant that uses BigQuery Data Agents"
|
||||
" to answer user questions about their data.\n\n## Tools\n- You can"
|
||||
" list available data agents using `list_accessible_data_agents`.\n-"
|
||||
" You can get information about a specific data agent using"
|
||||
" `get_data_agent_info`.\n- You can chat with a specific data"
|
||||
" agent using `ask_data_agent`.\n"
|
||||
),
|
||||
tools=[bq_da_toolset],
|
||||
)
|
||||
@@ -24,7 +24,6 @@ from ..utils.env_utils import is_env_enabled
|
||||
class FeatureName(str, Enum):
|
||||
"""Feature names."""
|
||||
|
||||
BIG_QUERY_DATA_AGENT_TOOLSET = "BIG_QUERY_DATA_AGENT_TOOLSET"
|
||||
BIG_QUERY_TOOLSET = "BIG_QUERY_TOOLSET"
|
||||
BIG_QUERY_TOOL_CONFIG = "BIG_QUERY_TOOL_CONFIG"
|
||||
BIGTABLE_TOOL_SETTINGS = "BIGTABLE_TOOL_SETTINGS"
|
||||
@@ -68,9 +67,6 @@ class FeatureConfig:
|
||||
|
||||
# Central registry: FeatureName -> FeatureConfig
|
||||
_FEATURE_REGISTRY: dict[FeatureName, FeatureConfig] = {
|
||||
FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
FeatureName.BIG_QUERY_TOOLSET: FeatureConfig(
|
||||
FeatureStage.EXPERIMENTAL, default_on=True
|
||||
),
|
||||
|
||||
@@ -28,11 +28,9 @@ definition. The rationales to have customized tool are:
|
||||
"""
|
||||
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .bigquery_data_agent_toolset import BigQueryDataAgentToolset
|
||||
from .bigquery_toolset import BigQueryToolset
|
||||
|
||||
__all__ = [
|
||||
"BigQueryToolset",
|
||||
"BigQueryCredentialsConfig",
|
||||
"BigQueryDataAgentToolset",
|
||||
]
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
from typing_extensions import override
|
||||
|
||||
from . import data_insights_tool
|
||||
from ...features import experimental
|
||||
from ...features import FeatureName
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.base_toolset import BaseToolset
|
||||
from ...tools.base_toolset import ToolPredicate
|
||||
from ...tools.google_tool import GoogleTool
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
|
||||
@experimental(FeatureName.BIG_QUERY_DATA_AGENT_TOOLSET)
|
||||
class BigQueryDataAgentToolset(BaseToolset):
|
||||
"""BigQuery Data Agent Toolset contains tools for interacting with BigQuery data agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
credentials_config: Optional[BigQueryCredentialsConfig] = None,
|
||||
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
|
||||
):
|
||||
super().__init__(tool_filter=tool_filter)
|
||||
self._credentials_config = credentials_config
|
||||
self._tool_settings = (
|
||||
bigquery_tool_config if bigquery_tool_config else BigQueryToolConfig()
|
||||
)
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||
) -> bool:
|
||||
if self.tool_filter is None:
|
||||
return True
|
||||
|
||||
if isinstance(self.tool_filter, ToolPredicate):
|
||||
return self.tool_filter(tool, readonly_context)
|
||||
|
||||
if isinstance(self.tool_filter, list):
|
||||
return tool.name in self.tool_filter
|
||||
|
||||
return False
|
||||
|
||||
@override
|
||||
async def get_tools(
|
||||
self, readonly_context: Optional[ReadonlyContext] = None
|
||||
) -> List[BaseTool]:
|
||||
"""Get tools from the toolset."""
|
||||
all_tools = [
|
||||
GoogleTool(
|
||||
func=func,
|
||||
credentials_config=self._credentials_config,
|
||||
tool_settings=self._tool_settings,
|
||||
)
|
||||
for func in [
|
||||
data_insights_tool.list_accessible_data_agents,
|
||||
data_insights_tool.get_data_agent_info,
|
||||
data_insights_tool.ask_data_agent,
|
||||
]
|
||||
]
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if self._is_tool_selected(tool, readonly_context)
|
||||
]
|
||||
|
||||
@override
|
||||
async def close(self):
|
||||
pass
|
||||
@@ -20,11 +20,9 @@ from typing import List
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
from google.cloud import geminidataanalytics
|
||||
import requests
|
||||
|
||||
from . import client
|
||||
from ..tool_context import ToolContext
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
|
||||
@@ -114,7 +112,6 @@ def ask_data_insights(
|
||||
]
|
||||
}
|
||||
"""
|
||||
# TODO(huanc): replace this with official client library.
|
||||
try:
|
||||
location = "global"
|
||||
if not credentials.token:
|
||||
@@ -166,329 +163,9 @@ def ask_data_insights(
|
||||
return {"status": "SUCCESS", "response": resp}
|
||||
|
||||
|
||||
def list_accessible_data_agents(
|
||||
project_id: str,
|
||||
credentials: Credentials,
|
||||
) -> Dict[str, Any]:
|
||||
"""Lists accessible data agents in a project.
|
||||
|
||||
Args:
|
||||
project_id: The project to list agents in.
|
||||
credentials: The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the status and a list of data agents with their
|
||||
detailed information, including name, display_name, description (if
|
||||
available), create_time, update_time, and data_analytics_agent context,
|
||||
or error details if the request fails.
|
||||
|
||||
Examples:
|
||||
>>> list_accessible_data_agents(
|
||||
... project_id="my-gcp-project",
|
||||
... credentials=credentials,
|
||||
... )
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"response": [
|
||||
{
|
||||
"name": "projects/my-project/locations/global/dataAgents/agent1",
|
||||
"display_name": "My Test Agent",
|
||||
"create_time": {"seconds": 1759358662, "nanos": 473927629},
|
||||
"update_time": {"seconds": 1759358663, "nanos": 94541325},
|
||||
"data_analytics_agent": {
|
||||
"published_context": {
|
||||
"datasource_references": [{
|
||||
"bq": {
|
||||
"table_references": [{
|
||||
"project_id": "my-project",
|
||||
"dataset_id": "dataset1",
|
||||
"table_id": "table1"
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "projects/my-project/locations/global/dataAgents/agent2",
|
||||
"display_name": "",
|
||||
"description": "Description for Agent 2.",
|
||||
"create_time": {"seconds": 1750710228, "nanos": 650597312},
|
||||
"update_time": {"seconds": 1750710229, "nanos": 437095391},
|
||||
"data_analytics_agent": {
|
||||
"published_context": {
|
||||
"datasource_references": [{
|
||||
"bq": {
|
||||
"table_references": [{
|
||||
"project_id": "another-project",
|
||||
"dataset_id": "dataset2",
|
||||
"table_id": "table2"
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"system_instruction": "You are a helpful assistant.",
|
||||
"options": {"analysis": {"python": {"enabled": True}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
client = geminidataanalytics.DataAgentServiceClient(credentials=credentials)
|
||||
request = geminidataanalytics.ListAccessibleDataAgentsRequest(
|
||||
parent=f"projects/{project_id}/locations/global",
|
||||
)
|
||||
page_result = client.list_accessible_data_agents(request=request)
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"response": [str(agent) for agent in page_result],
|
||||
}
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def get_data_agent_info(
|
||||
data_agent_name: str,
|
||||
credentials: Credentials,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gets a data agent by name.
|
||||
|
||||
Args:
|
||||
data_agent_name: The name of the agent to get, in format
|
||||
projects/{project}/locations/{location}/dataAgents/{agent}.
|
||||
credentials: The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the status and details of a data agent,
|
||||
including name, display_name, description (if available),
|
||||
create_time, update_time, and data_analytics_agent context,
|
||||
or error details if the request fails.
|
||||
|
||||
Examples:
|
||||
>>> get_data_agent_info(
|
||||
... data_agent_name="projects/my-project/locations/global/dataAgents/agent-1",
|
||||
... credentials=credentials,
|
||||
... )
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"response": {
|
||||
"name": "projects/my-project/locations/global/dataAgents/agent-1",
|
||||
"display_name": "My Agent 1",
|
||||
"description": "Description for Agent 1.",
|
||||
"create_time": {"seconds": 1750710228, "nanos": 650597312},
|
||||
"update_time": {"seconds": 1750710229, "nanos": 437095391},
|
||||
"data_analytics_agent": {
|
||||
"published_context": {
|
||||
"datasource_references": [{
|
||||
"bq": {
|
||||
"table_references": [{
|
||||
"project_id": "my-gcp-project",
|
||||
"dataset_id": "dataset1",
|
||||
"table_id": "table1"
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"system_instruction": "You are a helpful assistant.",
|
||||
"options": {"analysis": {"python": {"enabled": True}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
client = geminidataanalytics.DataAgentServiceClient(credentials=credentials)
|
||||
request = geminidataanalytics.GetDataAgentRequest(
|
||||
name=data_agent_name,
|
||||
)
|
||||
response = client.get_data_agent(request=request)
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"response": str(response),
|
||||
}
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def ask_data_agent(
|
||||
data_agent_name: str,
|
||||
query: str,
|
||||
*,
|
||||
credentials: Credentials,
|
||||
tool_context: ToolContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""Asks a question to a data agent.
|
||||
|
||||
Args:
|
||||
data_agent_name: The resource name of an existing data agent to ask,
|
||||
in format projects/{project}/locations/{location}/dataAgents/{agent}.
|
||||
query: The question to ask the agent.
|
||||
credentials: The credentials to use for the request.
|
||||
tool_context: The context for the tool.
|
||||
|
||||
Returns:
|
||||
A dictionary with two keys:
|
||||
- 'status': A string indicating the final status (e.g., "SUCCESS").
|
||||
- 'response': A list of dictionaries, where each dictionary
|
||||
represents a step in the agent's execution process (e.g., SQL
|
||||
generation, data retrieval, final answer). Note that the 'Answer'
|
||||
step contains a text response which may summarize findings or refer
|
||||
to previous steps of agent execution, such as 'Data Retrieved', in
|
||||
which cases, the 'Answer' step does not include the result data.
|
||||
|
||||
Examples:
|
||||
A query to a data agent, showing the full return structure.
|
||||
The original question: "Which customer from New York spent the most last
|
||||
month?"
|
||||
|
||||
>>> ask_data_agent(
|
||||
... data_agent_name="projects/my-project/locations/global/dataAgents/sales-agent",
|
||||
... query="Which customer from New York spent the most last month?",
|
||||
... credentials=credentials,
|
||||
... tool_context=tool_context,
|
||||
... )
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"response": [
|
||||
{
|
||||
"SQL Generated": "SELECT t1.customer_name, SUM(t2.order_total) ... "
|
||||
},
|
||||
{
|
||||
"Data Retrieved": {
|
||||
"headers": ["customer_name", "total_spent"],
|
||||
"rows": [["Jane Doe", 1234.56]],
|
||||
"summary": "Showing all 1 rows."
|
||||
}
|
||||
},
|
||||
{
|
||||
"Answer": "The customer who spent the most was Jane Doe."
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
agent_info = get_data_agent_info(data_agent_name, credentials)
|
||||
if agent_info.get("status") == "ERROR":
|
||||
return agent_info
|
||||
client = geminidataanalytics.DataChatServiceClient(credentials=credentials)
|
||||
parent = data_agent_name.rsplit("/", 2)[0]
|
||||
conversation_name = None
|
||||
|
||||
if (
|
||||
tool_context.state.get("bigquery_data_agent_conv_agent")
|
||||
== data_agent_name
|
||||
):
|
||||
conversation_name = tool_context.state.get(
|
||||
"bigquery_data_agent_conv_name"
|
||||
)
|
||||
else:
|
||||
conversation = geminidataanalytics.Conversation()
|
||||
conversation.agents = [data_agent_name]
|
||||
request = geminidataanalytics.CreateConversationRequest(
|
||||
parent=parent,
|
||||
conversation=conversation,
|
||||
)
|
||||
response = client.create_conversation(request=request)
|
||||
conversation_name = response.name
|
||||
tool_context.state["bigquery_data_agent_conv_agent"] = data_agent_name
|
||||
tool_context.state["bigquery_data_agent_conv_name"] = conversation_name
|
||||
|
||||
new_user_message = geminidataanalytics.Message()
|
||||
new_user_message.user_message.text = query
|
||||
messages = [new_user_message]
|
||||
|
||||
if conversation_name:
|
||||
conversation_reference = geminidataanalytics.ConversationReference()
|
||||
conversation_reference.conversation = conversation_name
|
||||
conversation_reference.data_agent_context.data_agent = data_agent_name
|
||||
request = geminidataanalytics.ChatRequest(
|
||||
parent=parent,
|
||||
messages=messages,
|
||||
conversation_reference=conversation_reference,
|
||||
)
|
||||
else:
|
||||
data_agent_context = geminidataanalytics.DataAgentContext()
|
||||
data_agent_context.data_agent = data_agent_name
|
||||
request = geminidataanalytics.ChatRequest(
|
||||
parent=parent,
|
||||
messages=messages,
|
||||
data_agent_context=data_agent_context,
|
||||
)
|
||||
stream = client.chat(request=request)
|
||||
responses = list(stream)
|
||||
print({
|
||||
"status": "SUCCESS",
|
||||
"response": _process_data_agent_stream(responses),
|
||||
})
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"response": _process_data_agent_stream(responses),
|
||||
}
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
def _process_result_message(result, max_rows: int) -> dict[str, Any]:
|
||||
"""Processes result message from data agent chat."""
|
||||
headers = [f.name for f in result.schema.fields]
|
||||
all_rows_structs = result.data
|
||||
total_rows = len(all_rows_structs)
|
||||
|
||||
summary_string = f"Showing all {total_rows} rows."
|
||||
if total_rows > max_rows:
|
||||
summary_string = f"Showing the first {max_rows} of {total_rows} total rows."
|
||||
rows = []
|
||||
i = 0
|
||||
for row in all_rows_structs:
|
||||
if i >= max_rows:
|
||||
break
|
||||
rows.append([row.get(h) for h in headers])
|
||||
i += 1
|
||||
return {
|
||||
"headers": headers,
|
||||
"rows": rows,
|
||||
"summary": summary_string,
|
||||
}
|
||||
|
||||
|
||||
def _process_data_agent_stream(
|
||||
stream: list[geminidataanalytics.Message], max_rows: int = 1000
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Processes stream from data agent chat."""
|
||||
processed_responses = []
|
||||
for i, msg in enumerate(stream):
|
||||
if msg.system_message:
|
||||
message = msg.system_message
|
||||
if message.text.parts:
|
||||
processed_responses.append({"Answer": "".join(message.text.parts)})
|
||||
elif message.data.generated_sql:
|
||||
processed_responses.append(
|
||||
{"SQL Generated": message.data.generated_sql}
|
||||
)
|
||||
elif message.data.result and message.data.result.data:
|
||||
processed_responses.append({
|
||||
"Data Retrieved": _process_result_message(
|
||||
message.data.result, max_rows
|
||||
)
|
||||
})
|
||||
elif message.error:
|
||||
processed_responses.append({"Error": message.error.text})
|
||||
return processed_responses
|
||||
|
||||
|
||||
def _get_stream(
|
||||
url: str,
|
||||
ca_payload: Dict[str, Any],
|
||||
*,
|
||||
headers: Dict[str, str],
|
||||
max_query_result_rows: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import types
|
||||
from unittest import mock
|
||||
|
||||
MOCK_GEMINI_DATA_ANALYTICS = mock.MagicMock()
|
||||
sys.modules["google.cloud.geminidataanalytics"] = MOCK_GEMINI_DATA_ANALYTICS
|
||||
|
||||
# The mock.patch calls require 'geminidataanalytics' to be an attribute of
|
||||
# 'google.cloud' module for patching by string to work.
|
||||
try:
|
||||
import google.cloud
|
||||
except ImportError:
|
||||
sys.modules["google.cloud"] = types.ModuleType("google.cloud")
|
||||
finally:
|
||||
sys.modules["google.cloud"].geminidataanalytics = MOCK_GEMINI_DATA_ANALYTICS
|
||||
@@ -1,130 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryDataAgentToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.google_tool import GoogleTool
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_data_agent_toolset_tools_default():
|
||||
"""Test default BigQueryDataAgentToolset.
|
||||
|
||||
This test verifies the behavior of the BigQueryDataAgentToolset when no filter is
|
||||
specified.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = BigQueryDataAgentToolset(
|
||||
credentials_config=credentials_config, bigquery_tool_config=None
|
||||
)
|
||||
# Verify that the tool config is initialized to default values.
|
||||
assert isinstance(toolset._tool_settings, BigQueryToolConfig) # pylint: disable=protected-access
|
||||
assert toolset._tool_settings.__dict__ == BigQueryToolConfig().__dict__ # pylint: disable=protected-access
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == 3
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set([
|
||||
"list_accessible_data_agents",
|
||||
"get_data_agent_info",
|
||||
"ask_data_agent",
|
||||
])
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"selected_tools",
|
||||
[
|
||||
pytest.param([], id="None"),
|
||||
pytest.param(
|
||||
["list_accessible_data_agents", "get_data_agent_info"],
|
||||
id="list_and_get",
|
||||
),
|
||||
pytest.param(["ask_data_agent"], id="ask"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_data_agent_toolset_tools_selective(selected_tools):
|
||||
"""Test BigQueryDataAgentToolset with filter.
|
||||
|
||||
This test verifies the behavior of the BigQueryDataAgentToolset when filter is
|
||||
specified. A use case for this would be when the agent builder wants to
|
||||
use only a subset of the tools provided by the toolset.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
toolset = BigQueryDataAgentToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(selected_tools)
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set(selected_tools)
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("selected_tools", "returned_tools"),
|
||||
[
|
||||
pytest.param(["unknown"], [], id="all-unknown"),
|
||||
pytest.param(
|
||||
["unknown", "ask_data_agent"],
|
||||
["ask_data_agent"],
|
||||
id="mixed-known-unknown",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigquery_data_agent_toolset_unknown_tool(
|
||||
selected_tools, returned_tools
|
||||
):
|
||||
"""Test BigQueryDataAgentToolset with filter.
|
||||
|
||||
This test verifies the behavior of the BigQueryDataAgentToolset when filter is
|
||||
specified with an unknown tool.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
|
||||
toolset = BigQueryDataAgentToolset(
|
||||
credentials_config=credentials_config, tool_filter=selected_tools
|
||||
)
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == len(returned_tools)
|
||||
assert all(isinstance(tool, GoogleTool) for tool in tools)
|
||||
|
||||
expected_tool_names = set(returned_tools)
|
||||
actual_tool_names = {tool.name for tool in tools}
|
||||
assert actual_tool_names == expected_tool_names
|
||||
@@ -16,7 +16,6 @@ import pathlib
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.bigquery import data_insights_tool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
@@ -270,134 +269,3 @@ def test_handle_error(response_dict, expected_output):
|
||||
"""Tests the error response handler."""
|
||||
result = data_insights_tool._handle_error(response_dict) # pylint: disable=protected-access
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
|
||||
)
|
||||
def test_list_accessible_data_agents_success(mock_data_agent_client):
|
||||
"""Tests list_accessible_data_agents success path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_agent1 = mock.MagicMock()
|
||||
mock_agent1.__str__.return_value = "agent1"
|
||||
mock_agent2 = mock.MagicMock()
|
||||
mock_agent2.__str__.return_value = "agent2"
|
||||
mock_data_agent_client.return_value.list_accessible_data_agents.return_value = [
|
||||
mock_agent1,
|
||||
mock_agent2,
|
||||
]
|
||||
result = data_insights_tool.list_accessible_data_agents(
|
||||
"test-project", mock_creds
|
||||
)
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["response"] == ["agent1", "agent2"]
|
||||
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
|
||||
)
|
||||
def test_list_accessible_data_agents_exception(mock_data_agent_client):
|
||||
"""Tests list_accessible_data_agents exception path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_data_agent_client.return_value.list_accessible_data_agents.side_effect = Exception(
|
||||
"List failed!"
|
||||
)
|
||||
result = data_insights_tool.list_accessible_data_agents(
|
||||
"test-project", mock_creds
|
||||
)
|
||||
assert result["status"] == "ERROR"
|
||||
assert "List failed!" in result["error_details"]
|
||||
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
|
||||
)
|
||||
def test_get_data_agent_info_success(mock_data_agent_client):
|
||||
"""Tests get_data_agent_info success path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_response = mock.MagicMock()
|
||||
mock_response.__str__.return_value = "agent_info"
|
||||
mock_data_agent_client.return_value.get_data_agent.return_value = (
|
||||
mock_response
|
||||
)
|
||||
result = data_insights_tool.get_data_agent_info("agent_name", mock_creds)
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["response"] == "agent_info"
|
||||
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataAgentServiceClient"
|
||||
)
|
||||
def test_get_data_agent_info_exception(mock_data_agent_client):
|
||||
"""Tests get_data_agent_info exception path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_data_agent_client.return_value.get_data_agent.side_effect = Exception(
|
||||
"Get failed!"
|
||||
)
|
||||
result = data_insights_tool.get_data_agent_info("agent_name", mock_creds)
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Get failed!" in result["error_details"]
|
||||
mock_data_agent_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataChatServiceClient"
|
||||
)
|
||||
def test_ask_data_agent_success(mock_data_chat_client):
|
||||
"""Tests ask_data_agent success path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_invocation_context = mock.Mock()
|
||||
mock_invocation_context.session.state = {}
|
||||
mock_context = ToolContext(mock_invocation_context)
|
||||
mock_response1 = mock.MagicMock()
|
||||
mock_response1.system_message.text.parts = ["response1"]
|
||||
mock_response1.system_message.data.generated_sql = None
|
||||
mock_response1.system_message.data.result = None
|
||||
mock_response1.system_message.error = None
|
||||
mock_response2 = mock.MagicMock()
|
||||
mock_response2.system_message.text.parts = ["response2"]
|
||||
mock_response2.system_message.data.generated_sql = None
|
||||
mock_response2.system_message.data.result = None
|
||||
mock_response2.system_message.error = None
|
||||
mock_data_chat_client.return_value.chat.return_value = [
|
||||
mock_response1,
|
||||
mock_response2,
|
||||
]
|
||||
result = data_insights_tool.ask_data_agent(
|
||||
"projects/p/locations/l/dataAgents/a",
|
||||
"query",
|
||||
credentials=mock_creds,
|
||||
tool_context=mock_context,
|
||||
)
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert result["response"] == [
|
||||
{"Answer": "response1"},
|
||||
{"Answer": "response2"},
|
||||
]
|
||||
mock_data_chat_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
|
||||
@mock.patch.object(
|
||||
data_insights_tool.geminidataanalytics, "DataChatServiceClient"
|
||||
)
|
||||
def test_ask_data_agent_exception(mock_data_chat_client):
|
||||
"""Tests ask_data_agent exception path."""
|
||||
mock_creds = mock.Mock()
|
||||
mock_invocation_context = mock.Mock()
|
||||
mock_invocation_context.session.state = {}
|
||||
mock_context = ToolContext(mock_invocation_context)
|
||||
mock_data_chat_client.return_value.chat.side_effect = Exception(
|
||||
"Chat failed!"
|
||||
)
|
||||
result = data_insights_tool.ask_data_agent(
|
||||
"projects/p/locations/l/dataAgents/a",
|
||||
"query",
|
||||
credentials=mock_creds,
|
||||
tool_context=mock_context,
|
||||
)
|
||||
assert result["status"] == "ERROR"
|
||||
assert "Chat failed!" in result["error_details"]
|
||||
mock_data_chat_client.assert_called_once_with(credentials=mock_creds)
|
||||
|
||||
Reference in New Issue
Block a user