You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add BigQuery analyze_contribution tool
This change introduces a new `analyze_contribution` function in `query_tool.py` that uses BigQuery ML's `CREATE MODEL` statement with the `CONTRIBUTION_ANALYSIS` model type, together with `ML.GET_INSIGHTS`, to analyze the contribution of different dimensions to a given metric. The new function is also added to the `bigquery_toolset`.

PiperOrigin-RevId: 815849281
This commit is contained in:
committed by
Copybara-Service
parent
5c6cdcd197
commit
4bb089d386
@@ -82,6 +82,7 @@ class BigQueryToolset(BaseToolset):
|
||||
metadata_tool.list_table_ids,
|
||||
query_tool.get_execute_sql(self._tool_settings),
|
||||
query_tool.forecast,
|
||||
query_tool.analyze_contribution,
|
||||
data_insights_tool.ask_data_insights,
|
||||
]
|
||||
]
|
||||
|
||||
@@ -19,6 +19,7 @@ import json
|
||||
import types
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud import bigquery
|
||||
@@ -892,3 +893,202 @@ def forecast(
|
||||
)
|
||||
"""
|
||||
return execute_sql(project_id, query, credentials, settings, tool_context)
|
||||
|
||||
|
||||
def analyze_contribution(
    project_id: str,
    input_data: str,
    contribution_metric: str,
    dimension_id_cols: list[str],
    is_test_col: str,
    credentials: Credentials,
    settings: BigQueryToolConfig,
    tool_context: ToolContext,
    top_k_insights: int = 30,
    pruning_method: str = "PRUNE_REDUNDANT_INSIGHTS",
) -> dict:
  """Run a BigQuery ML contribution analysis using CREATE MODEL and ML.GET_INSIGHTS.

  A temporary CONTRIBUTION_ANALYSIS model is created from `input_data` and
  its insights are then read back with ML.GET_INSIGHTS. Both statements are
  issued through `execute_sql`. While they run, `settings.write_mode` is
  switched to PROTECTED (unless it already is) and it is restored to its
  original value before this function returns.

  Args:
      project_id (str): The GCP project id in which the query should be
        executed.
      input_data (str): The data that contain the test and control data to
        analyze. Can be a fully qualified BigQuery table ID or a SQL query.
      contribution_metric (str): The name of the column that contains the metric
        to analyze. Provides the expression to use to calculate the metric you
        are analyzing. To calculate a summable metric, the expression must be in
        the form SUM(metric_column_name), where metric_column_name is a numeric
        data type. To calculate a summable ratio metric, the expression must be
        in the form
        SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name),
        where numerator_metric_column_name and denominator_metric_column_name
        are numeric data types. To calculate a summable by category metric, the
        expression must be in the form
        SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name). The
        summed column must be a numeric data type. The categorical column must
        have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.
      dimension_id_cols (list[str]): The column names of the dimension columns.
      is_test_col (str): The name of the column to use to determine whether a
        given row is test data or control data. The column must have a BOOL data
        type.
      credentials: The credentials to use for the request.
      settings: The settings for the tool.
      tool_context: The context for the tool.
      top_k_insights (int, optional): The number of top insights to return,
        ranked by apriori support. Defaults to 30.
      pruning_method (str, optional): The method to use for pruning redundant
        insights. Can be 'NO_PRUNING' or 'PRUNE_REDUNDANT_INSIGHTS'. Defaults to
        "PRUNE_REDUNDANT_INSIGHTS".

  Returns:
      dict: Dictionary representing the result of the contribution analysis.
        On failure (invalid arguments, a blocked write mode, a failed
        CREATE MODEL statement, or any exception) the dictionary has
        "status" set to "ERROR" (or whatever non-success result `execute_sql`
        returned) instead of the insight rows.

  Examples:
      Analyze the contribution of different dimensions to the total sales:

          >>> analyze_contribution(
          ...     project_id="my-gcp-project",
          ...     input_data="my-dataset.my-sales-table",
          ...     dimension_id_cols=["store_id", "product_category"],
          ...     contribution_metric="SUM(total_sales)",
          ...     is_test_col="is_test"
          ... )
          The return is:
          {
            "status": "SUCCESS",
            "rows": [
              {
                "store_id": "S1",
                "product_category": "Electronics",
                "contributors": ["S1", "Electronics"],
                "metric_test": 120,
                "metric_control": 100,
                "difference": 20,
                "relative_difference": 0.2,
                "unexpected_difference": 5,
                "relative_unexpected_difference": 0.043,
                "apriori_support": 0.15
              },
              ...
            ]
          }

      Analyze the contribution of different dimensions to the total sales using
      a SQL query as input:

          >>> analyze_contribution(
          ...     project_id="my-gcp-project",
          ...     input_data="SELECT store_id, product_category, total_sales, "
          ...     "is_test FROM `my-project.my-dataset.my-sales-table` "
          ...     "WHERE transaction_date > '2025-01-01'",
          ...     dimension_id_cols=["store_id", "product_category"],
          ...     contribution_metric="SUM(total_sales)",
          ...     is_test_col="is_test"
          ... )
          The return is:
          {
            "status": "SUCCESS",
            "rows": [
              {
                "store_id": "S2",
                "product_category": "Groceries",
                "contributors": ["S2", "Groceries"],
                "metric_test": 250,
                "metric_control": 200,
                "difference": 50,
                "relative_difference": 0.25,
                "unexpected_difference": 10,
                "relative_unexpected_difference": 0.041,
                "apriori_support": 0.22
              },
              ...
            ]
          }
  """
  if not all(isinstance(item, str) for item in dimension_id_cols):
    return {
        "status": "ERROR",
        "error_details": "All elements in dimension_id_cols must be strings.",
    }

  # Generate a unique temporary model name. Dashes are not valid in an
  # unquoted identifier, so they are replaced with underscores.
  model_suffix = str(uuid.uuid4()).replace("-", "_")
  model_name = f"contribution_analysis_model_{model_suffix}"

  id_cols_str = "[" + ", ".join(f"'{col}'" for col in dimension_id_cols) + "]"
  options = [
      "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'",
      f"CONTRIBUTION_METRIC = '{contribution_metric}'",
      f"IS_TEST_COL = '{is_test_col}'",
      f"DIMENSION_ID_COLS = {id_cols_str}",
      f"TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = {top_k_insights}",
  ]

  upper_pruning = pruning_method.upper()
  if upper_pruning not in ("NO_PRUNING", "PRUNE_REDUNDANT_INSIGHTS"):
    return {
        "status": "ERROR",
        "error_details": f"Invalid pruning_method: {pruning_method}",
    }
  options.append(f"PRUNING_METHOD = '{upper_pruning}'")

  options_str = ", ".join(options)

  # Anything that looks like a query is wrapped in parentheses; everything
  # else is treated as a table id and selected from in full.
  if input_data.strip().upper().startswith(("SELECT", "WITH")):
    input_data_source = f"({input_data})"
  else:
    input_data_source = f"SELECT * FROM `{input_data}`"

  # The unit tests compare these query strings character for character, so
  # the whitespace inside the literals is significant.
  # NOTE(review): leading whitespace inside these literals could not be
  # recovered from the rendered diff - confirm against the repository.
  create_model_query = f"""
CREATE TEMP MODEL {model_name}
OPTIONS ({options_str})
AS {input_data_source}
"""

  get_insights_query = f"""
SELECT * FROM ML.GET_INSIGHTS(MODEL {model_name})
"""

  # Create a session and run the create model query.
  original_write_mode = settings.write_mode
  try:
    if original_write_mode == WriteMode.BLOCKED:
      # Converted into an ERROR result by the except clause below.
      raise ValueError("analyze_contribution is not allowed in this session.")
    elif original_write_mode != WriteMode.PROTECTED:
      # Running create temp model requires a session. So we set the write mode
      # to PROTECTED to run the create model query and job query in the same
      # session.
      settings.write_mode = WriteMode.PROTECTED

    result = execute_sql(
        project_id,
        create_model_query,
        credentials,
        settings,
        tool_context,
    )
    if result["status"] != "SUCCESS":
      return result

    result = execute_sql(
        project_id,
        get_insights_query,
        credentials,
        settings,
        tool_context,
    )
  except Exception as ex:  # pylint: disable=broad-except
    return {
        "status": "ERROR",
        "error_details": f"Error during analyze_contribution: {ex}",
    }
  finally:
    # Restore the original write mode. This must be an assignment: the
    # caller's settings object is shared and would otherwise stay PROTECTED.
    settings.write_mode = original_write_mode

  return result
|
||||
|
||||
@@ -28,6 +28,7 @@ from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
from google.adk.tools.bigquery.query_tool import analyze_contribution
|
||||
from google.adk.tools.bigquery.query_tool import execute_sql
|
||||
from google.adk.tools.bigquery.query_tool import forecast
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
@@ -874,7 +875,8 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
|
||||
"""Test execute_sql tool for non-SELECT query when writes are protected.
|
||||
|
||||
This is a special case when the destination table is a persistent/permananent
|
||||
one and the protected write is enabled. In this case the operation should fail.
|
||||
one and the protected write is enabled. In this case the operation should
|
||||
fail.
|
||||
"""
|
||||
project = "my_project"
|
||||
query_result = []
|
||||
@@ -1272,3 +1274,130 @@ def test_forecast_with_invalid_id_cols():
|
||||
|
||||
assert result["status"] == "ERROR"
|
||||
assert "All elements in id_cols must be strings." in result["error_details"]
|
||||
|
||||
|
||||
# analyze_contribution calls execute_sql twice. We need to test that the
|
||||
# queries are properly constructed and call execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
  """Check the two queries issued when input_data is a table id.

  The uuid is pinned so the temporary model name is predictable, and
  execute_sql is replaced so that only the generated SQL and the forwarded
  arguments are verified.
  """
  mock_uuid.return_value = "test_uuid"
  mock_execute_sql.return_value = {"status": "SUCCESS"}

  credentials = mock.MagicMock(spec=Credentials)
  tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
  context = mock.create_autospec(ToolContext, instance=True)

  analyze_contribution(
      project_id="test-project",
      input_data="test-dataset.test-table",
      dimension_id_cols=["dim1", "dim2"],
      contribution_metric="SUM(metric)",
      is_test_col="is_test",
      credentials=credentials,
      settings=tool_config,
      tool_context=context,
  )

  # A bare table id is expected to be expanded into a SELECT * statement.
  create_model_sql = """
CREATE TEMP MODEL contribution_analysis_model_test_uuid
OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS')
AS SELECT * FROM `test-dataset.test-table`
"""

  insights_sql = """
SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model_test_uuid)
"""

  assert mock_execute_sql.call_count == 2
  for expected_sql in (create_model_sql, insights_sql):
    mock_execute_sql.assert_any_call(
        "test-project",
        expected_sql,
        credentials,
        tool_config,
        context,
    )
|
||||
|
||||
|
||||
# analyze_contribution calls execute_sql twice. We need to test that the
|
||||
# queries are properly constructed and call execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool.execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql):
  """Check the two queries issued when input_data is a SQL statement.

  The uuid is pinned so the temporary model name is predictable, and
  execute_sql is replaced so that only the generated SQL and the forwarded
  arguments are verified.
  """
  mock_uuid.return_value = "test_uuid"
  mock_execute_sql.return_value = {"status": "SUCCESS"}

  credentials = mock.MagicMock(spec=Credentials)
  tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
  context = mock.create_autospec(ToolContext, instance=True)

  source_query = "SELECT * FROM `test-dataset.test-table`"
  analyze_contribution(
      project_id="test-project",
      input_data=source_query,
      dimension_id_cols=["dim1", "dim2"],
      contribution_metric="SUM(metric)",
      is_test_col="is_test",
      credentials=credentials,
      settings=tool_config,
      tool_context=context,
  )

  # A query passed as input is expected to be embedded in parentheses.
  create_model_sql = f"""
CREATE TEMP MODEL contribution_analysis_model_test_uuid
OPTIONS (MODEL_TYPE = 'CONTRIBUTION_ANALYSIS', CONTRIBUTION_METRIC = 'SUM(metric)', IS_TEST_COL = 'is_test', DIMENSION_ID_COLS = ['dim1', 'dim2'], TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = 30, PRUNING_METHOD = 'PRUNE_REDUNDANT_INSIGHTS')
AS ({source_query})
"""

  insights_sql = """
SELECT * FROM ML.GET_INSIGHTS(MODEL contribution_analysis_model_test_uuid)
"""

  assert mock_execute_sql.call_count == 2
  for expected_sql in (create_model_sql, insights_sql):
    mock_execute_sql.assert_any_call(
        "test-project",
        expected_sql,
        credentials,
        tool_config,
        context,
    )
|
||||
|
||||
|
||||
def test_analyze_contribution_with_invalid_dimension_id_cols():
  """Check that a non-string entry in dimension_id_cols yields an ERROR result."""
  call_kwargs = dict(
      project_id="test-project",
      input_data="test-dataset.test-table",
      # The integer entry is the invalid element under test.
      dimension_id_cols=["dim1", 123],
      contribution_metric="metric",
      is_test_col="is_test",
      credentials=mock.MagicMock(spec=Credentials),
      settings=BigQueryToolConfig(),
      tool_context=mock.create_autospec(ToolContext, instance=True),
  )

  result = analyze_contribution(**call_kwargs)

  expected_message = "All elements in dimension_id_cols must be strings."
  assert result["status"] == "ERROR"
  assert expected_message in result["error_details"]
|
||||
|
||||
@@ -41,7 +41,7 @@ async def test_bigquery_toolset_tools_default():
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
|
||||
assert len(tools) == 7
|
||||
assert len(tools) == 8
|
||||
assert all([isinstance(tool, GoogleTool) for tool in tools])
|
||||
|
||||
expected_tool_names = set([
|
||||
@@ -52,6 +52,7 @@ async def test_bigquery_toolset_tools_default():
|
||||
"execute_sql",
|
||||
"ask_data_insights",
|
||||
"forecast",
|
||||
"analyze_contribution",
|
||||
])
|
||||
actual_tool_names = set([tool.name for tool in tools])
|
||||
assert actual_tool_names == expected_tool_names
|
||||
|
||||
Reference in New Issue
Block a user