feat: Add BigQuery analyze_contribution tool

This change introduces a new `analyze_contribution` function in `query_tool.py` which uses BigQuery ML's `CREATE MODEL` with `CONTRIBUTION_ANALYSIS` type and `ML.GET_INSIGHTS` to analyze the contribution of different dimensions to a given metric. The new function is also added to the `bigquery_toolset`.

PiperOrigin-RevId: 815849281
This commit is contained in:
Haoming Chen
2025-10-06 12:58:34 -07:00
committed by Copybara-Service
parent 5c6cdcd197
commit 4bb089d386
4 changed files with 333 additions and 2 deletions
@@ -82,6 +82,7 @@ class BigQueryToolset(BaseToolset):
metadata_tool.list_table_ids,
query_tool.get_execute_sql(self._tool_settings),
query_tool.forecast,
query_tool.analyze_contribution,
data_insights_tool.ask_data_insights,
]
]
+200
View File
@@ -19,6 +19,7 @@ import json
import types
from typing import Callable
from typing import Optional
import uuid
from google.auth.credentials import Credentials
from google.cloud import bigquery
@@ -892,3 +893,202 @@ def forecast(
)
"""
return execute_sql(project_id, query, credentials, settings, tool_context)
def analyze_contribution(
    project_id: str,
    input_data: str,
    contribution_metric: str,
    dimension_id_cols: list[str],
    is_test_col: str,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    tool_context: ToolContext,
    top_k_insights: int = 30,
    pruning_method: str = "PRUNE_REDUNDANT_INSIGHTS",
) -> dict:
  """Run a BigQuery ML contribution analysis using CREATE MODEL and ML.GET_INSIGHTS.

  Args:
    project_id (str): The GCP project id in which the query should be
      executed.
    input_data (str): The data that contain the test and control data to
      analyze. Can be a fully qualified BigQuery table ID or a SQL query.
    contribution_metric (str): The name of the column that contains the metric
      to analyze. Provides the expression to use to calculate the metric you
      are analyzing. To calculate a summable metric, the expression must be in
      the form SUM(metric_column_name), where metric_column_name is a numeric
      data type. To calculate a summable ratio metric, the expression must be
      in the form
      SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name),
      where numerator_metric_column_name and denominator_metric_column_name
      are numeric data types. To calculate a summable by category metric, the
      expression must be in the form
      SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name). The
      summed column must be a numeric data type. The categorical column must
      have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.
    dimension_id_cols (list[str]): The column names of the dimension columns.
    is_test_col (str): The name of the column to use to determine whether a
      given row is test data or control data. The column must have a BOOL data
      type.
    credentials: The credentials to use for the request.
    settings: The settings for the tool.
    tool_context: The context for the tool.
    top_k_insights (int, optional): The number of top insights to return,
      ranked by apriori support. Defaults to 30.
    pruning_method (str, optional): The method to use for pruning redundant
      insights. Can be 'NO_PRUNING' or 'PRUNE_REDUNDANT_INSIGHTS'. Defaults to
      "PRUNE_REDUNDANT_INSIGHTS".

  Returns:
    dict: Dictionary representing the result of the contribution analysis.

  Examples:
    Analyze the contribution of different dimensions to the total sales:

    >>> analyze_contribution(
    ...     project_id="my-gcp-project",
    ...     input_data="my-dataset.my-sales-table",
    ...     dimension_id_cols=["store_id", "product_category"],
    ...     contribution_metric="SUM(total_sales)",
    ...     is_test_col="is_test"
    ... )
    The return is:
    {
      "status": "SUCCESS",
      "rows": [
        {
          "store_id": "S1",
          "product_category": "Electronics",
          "contributors": ["S1", "Electronics"],
          "metric_test": 120,
          "metric_control": 100,
          "difference": 20,
          "relative_difference": 0.2,
          "unexpected_difference": 5,
          "relative_unexpected_difference": 0.043,
          "apriori_support": 0.15
        },
        ...
      ]
    }

    Analyze the contribution of different dimensions to the total sales using
    a SQL query as input:

    >>> analyze_contribution(
    ...     project_id="my-gcp-project",
    ...     input_data="SELECT store_id, product_category, total_sales, "
    ...         "is_test FROM `my-project.my-dataset.my-sales-table` "
    ...         "WHERE transaction_date > '2025-01-01'",
    ...     dimension_id_cols=["store_id", "product_category"],
    ...     contribution_metric="SUM(total_sales)",
    ...     is_test_col="is_test"
    ... )
    The return is:
    {
      "status": "SUCCESS",
      "rows": [
        {
          "store_id": "S2",
          "product_category": "Groceries",
          "contributors": ["S2", "Groceries"],
          "metric_test": 250,
          "metric_control": 200,
          "difference": 50,
          "relative_difference": 0.25,
          "unexpected_difference": 10,
          "relative_unexpected_difference": 0.041,
          "apriori_support": 0.22
        },
        ...
      ]
    }
  """
  if not all(isinstance(item, str) for item in dimension_id_cols):
    return {
        "status": "ERROR",
        "error_details": "All elements in dimension_id_cols must be strings.",
    }
  # Generate a unique temporary model name so concurrent invocations cannot
  # collide on the same model in a session.
  model_name = (
      f"contribution_analysis_model_{str(uuid.uuid4()).replace('-', '_')}"
  )
  id_cols_str = "[" + ", ".join([f"'{col}'" for col in dimension_id_cols]) + "]"
  options = [
      "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'",
      f"CONTRIBUTION_METRIC = '{contribution_metric}'",
      f"IS_TEST_COL = '{is_test_col}'",
      f"DIMENSION_ID_COLS = {id_cols_str}",
  ]
  options.append(f"TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = {top_k_insights}")
  upper_pruning = pruning_method.upper()
  if upper_pruning not in ["NO_PRUNING", "PRUNE_REDUNDANT_INSIGHTS"]:
    return {
        "status": "ERROR",
        "error_details": f"Invalid pruning_method: {pruning_method}",
    }
  options.append(f"PRUNING_METHOD = '{upper_pruning}'")
  options_str = ", ".join(options)
  # A leading SELECT/WITH means input_data is a query; otherwise treat it as a
  # table id. NOTE(review): values are interpolated directly into SQL; callers
  # are expected to pass trusted identifiers/expressions.
  trimmed_upper_input_data = input_data.strip().upper()
  if trimmed_upper_input_data.startswith(
      "SELECT"
  ) or trimmed_upper_input_data.startswith("WITH"):
    input_data_source = f"({input_data})"
  else:
    input_data_source = f"SELECT * FROM `{input_data}`"
  create_model_query = f"""
    CREATE TEMP MODEL {model_name}
    OPTIONS ({options_str})
    AS {input_data_source}
  """
  get_insights_query = f"""
    SELECT * FROM ML.GET_INSIGHTS(MODEL {model_name})
  """
  # Create a session and run the create model query.
  original_write_mode = settings.write_mode
  try:
    if settings.write_mode == WriteMode.BLOCKED:
      raise ValueError("analyze_contribution is not allowed in this session.")
    elif original_write_mode != WriteMode.PROTECTED:
      # Running create temp model requires a session. So we set the write mode
      # to PROTECTED to run the create model query and job query in the same
      # session.
      settings.write_mode = WriteMode.PROTECTED
    result = execute_sql(
        project_id,
        create_model_query,
        credentials,
        settings,
        tool_context,
    )
    if result["status"] != "SUCCESS":
      return result
    result = execute_sql(
        project_id,
        get_insights_query,
        credentials,
        settings,
        tool_context,
    )
  except Exception as ex:  # pylint: disable=broad-except
    return {
        "status": "ERROR",
        "error_details": f"Error during analyze_contribution: {str(ex)}",
    }
  finally:
    # Restore the original write mode. (Bug fix: this previously used `==`,
    # a no-op comparison, which permanently left the caller's settings in
    # PROTECTED write mode.)
    settings.write_mode = original_write_mode
  return result
@@ -28,6 +28,7 @@ from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
from google.adk.tools.bigquery.query_tool import analyze_contribution
from google.adk.tools.bigquery.query_tool import execute_sql
from google.adk.tools.bigquery.query_tool import forecast
from google.adk.tools.tool_context import ToolContext
@@ -874,7 +875,8 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
"""Test execute_sql tool for non-SELECT query when writes are protected.
This is a special case when the destination table is a persistent/permanent
one and the protected write is enabled. In this case the operation should fail.
one and the protected write is enabled. In this case the operation should
fail.
"""
project = "my_project"
query_result = []
@@ -1272,3 +1274,130 @@ def test_forecast_with_invalid_id_cols():
assert result["status"] == "ERROR"
assert "All elements in id_cols must be strings." in result["error_details"]
# analyze_contribution issues exactly two execute_sql calls; verify that the
# CREATE MODEL and ML.GET_INSIGHTS queries are constructed correctly and are
# each dispatched with the expected arguments.
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
def test_analyze_contribution_with_table_id(fake_uuid, fake_execute_sql):
  """Test analyze_contribution tool invocation with a table id."""
  fake_uuid.return_value = "test_uuid"
  fake_execute_sql.return_value = {"status": "SUCCESS"}
  credentials = mock.MagicMock(spec=Credentials)
  settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
  tool_context = mock.create_autospec(ToolContext, instance=True)

  analyze_contribution(
      project_id="test-project",
      input_data="test-dataset.test-table",
      dimension_id_cols=["dim1", "dim2"],
      contribution_metric="SUM(metric)",
      is_test_col="is_test",
      credentials=credentials,
      settings=settings,
      tool_context=tool_context,
  )

  expected_queries = [
      """
    CREATE TEMP MODEL contribution_analysis_model_test_uuid
    OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS')
    AS SELECT * FROM `test-dataset.test-table`
  """,
      """
    SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model_test_uuid)
  """,
  ]
  assert fake_execute_sql.call_count == 2
  for expected_query in expected_queries:
    fake_execute_sql.assert_any_call(
        "test-project",
        expected_query,
        credentials,
        settings,
        tool_context,
    )
# analyze_contribution issues exactly two execute_sql calls; verify that a SQL
# query passed as input_data is wrapped in parentheses rather than treated as
# a table id, and that both queries go out with the expected arguments.
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
def test_analyze_contribution_with_query_statement(fake_uuid, fake_execute_sql):
  """Test analyze_contribution tool invocation with a query statement."""
  fake_uuid.return_value = "test_uuid"
  fake_execute_sql.return_value = {"status": "SUCCESS"}
  credentials = mock.MagicMock(spec=Credentials)
  settings = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
  tool_context = mock.create_autospec(ToolContext, instance=True)
  source_query = "SELECT * FROM `test-dataset.test-table`"

  analyze_contribution(
      project_id="test-project",
      input_data=source_query,
      dimension_id_cols=["dim1", "dim2"],
      contribution_metric="SUM(metric)",
      is_test_col="is_test",
      credentials=credentials,
      settings=settings,
      tool_context=tool_context,
  )

  expected_queries = [
      f"""
    CREATE TEMP MODEL contribution_analysis_model_test_uuid
    OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS')
    AS ({source_query})
  """,
      """
    SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model_test_uuid)
  """,
  ]
  assert fake_execute_sql.call_count == 2
  for expected_query in expected_queries:
    fake_execute_sql.assert_any_call(
        "test-project",
        expected_query,
        credentials,
        settings,
        tool_context,
    )
def test_analyze_contribution_with_invalid_dimension_id_cols():
  """Test analyze_contribution tool invocation with invalid dimension_id_cols."""
  # A non-string element in dimension_id_cols must be rejected up front,
  # before any query is built or executed.
  result = analyze_contribution(
      project_id="test-project",
      input_data="test-dataset.test-table",
      dimension_id_cols=["dim1", 123],
      contribution_metric="metric",
      is_test_col="is_test",
      credentials=mock.MagicMock(spec=Credentials),
      settings=BigQueryToolConfig(),
      tool_context=mock.create_autospec(ToolContext, instance=True),
  )

  assert result["status"] == "ERROR"
  assert (
      "All elements in dimension_id_cols must be strings."
      in result["error_details"]
  )
@@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default():
tools = await toolset.get_tools()
assert tools is not None
assert len(tools) == 7
assert len(tools) == 8
assert all([isinstance(tool, GoogleTool) for tool in tools])
expected_tool_names = set([
@@ -52,6 +52,7 @@ async def test_bigquery_toolset_tools_default():
"execute_sql",
"ask_data_insights",
"forecast",
"analyze_contribution",
])
actual_tool_names = set([tool.name for tool in tools])
assert actual_tool_names == expected_tool_names