chore: make readability improvements in ADK BQ tool tests

PiperOrigin-RevId: 832110063
This commit is contained in:
Google Team Member
2025-11-13 19:42:48 -08:00
committed by Copybara-Service
parent 5adbf95a0a
commit b7d571bc3f
5 changed files with 116 additions and 116 deletions
@@ -19,7 +19,9 @@ from unittest import mock
import google.adk
from google.adk.tools.bigquery.client import get_bigquery_client
import google.auth
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.bigquery import client as bigquery_client
from google.oauth2.credentials import Credentials
@@ -41,7 +43,9 @@ def test_bigquery_client_project_set_explicit():
# Let's simulate that no environment variables are set, so that any project
# set in there does not interfere with this test
with mock.patch.dict(os.environ, {}, clear=True):
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
with mock.patch.object(
google.auth, "default", autospec=True
) as mock_default_auth:
# Simulate exception from default auth
mock_default_auth.side_effect = DefaultCredentialsError(
"Your default credentials were not found"
@@ -66,7 +70,9 @@ def test_bigquery_client_project_set_with_default_auth():
# Let's simulate that no environment variables are set, so that any project
# set in there does not interfere with this test
with mock.patch.dict(os.environ, {}, clear=True):
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
with mock.patch.object(
google.auth, "default", autospec=True
) as mock_default_auth:
# Simulate credentials
mock_creds = mock.create_autospec(Credentials, instance=True)
@@ -90,7 +96,9 @@ def test_bigquery_client_project_set_with_env():
with mock.patch.dict(
os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True
):
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
with mock.patch.object(
google.auth, "default", autospec=True
) as mock_default_auth:
# Simulate exception from default auth
mock_default_auth.side_effect = DefaultCredentialsError(
"Your default credentials were not found"
@@ -112,8 +120,8 @@ def test_bigquery_client_project_set_with_env():
def test_bigquery_client_user_agent_default():
"""Test BigQuery client default user agent."""
with mock.patch(
"google.cloud.bigquery.client.Connection", autospec=True
with mock.patch.object(
bigquery_client, "Connection", autospec=True
) as mock_connection:
# Trigger the BigQuery client creation
get_bigquery_client(
@@ -134,8 +142,8 @@ def test_bigquery_client_user_agent_default():
def test_bigquery_client_user_agent_custom():
"""Test BigQuery client custom user agent."""
with mock.patch(
"google.cloud.bigquery.client.Connection", autospec=True
with mock.patch.object(
bigquery_client, "Connection", autospec=True
) as mock_connection:
# Trigger the BigQuery client creation
get_bigquery_client(
@@ -158,8 +166,8 @@ def test_bigquery_client_user_agent_custom():
def test_bigquery_client_user_agent_custom_list():
"""Test BigQuery client custom user agent."""
with mock.patch(
"google.cloud.bigquery.client.Connection", autospec=True
with mock.patch.object(
bigquery_client, "Connection", autospec=True
) as mock_connection:
# Trigger the BigQuery client creation
get_bigquery_client(
@@ -36,7 +36,6 @@ class TestBigQueryCredentials:
to pass them directly without needing to provide client ID/secret.
"""
# Create a mock auth credentials object
# auth_creds = google.auth.credentials.Credentials()
auth_creds = mock.create_autospec(
google.auth.credentials.Credentials, instance=True
)
@@ -26,9 +26,7 @@ import yaml
pytest.param("test_data/ask_data_insights_penguins_highest_mass.yaml"),
],
)
@mock.patch(
"google.adk.tools.bigquery.data_insights_tool.requests.Session.post"
)
@mock.patch.object(data_insights_tool.requests.Session, "post")
def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path):
"""Runs a full integration test for the ask_data_insights pipeline using data from a specific file."""
# 1. Construct the full, absolute path to the data file
@@ -65,7 +63,7 @@ def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path):
assert result == expected_final_list
@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream")
@mock.patch.object(data_insights_tool, "_get_stream")
def test_ask_data_insights_success(mock_get_stream):
"""Tests the success path of ask_data_insights using decorators."""
# 1. Configure the behavior of the mocked functions
@@ -92,7 +90,7 @@ def test_ask_data_insights_success(mock_get_stream):
mock_get_stream.assert_called_once()
@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream")
@mock.patch.object(data_insights_tool, "_get_stream")
def test_ask_data_insights_handles_exception(mock_get_stream):
"""Tests the exception path of ask_data_insights using decorators."""
# 1. Configure one of the mocks to raise an error
@@ -17,16 +17,18 @@ from __future__ import annotations
import os
from unittest import mock
from google.adk.tools.bigquery import client as bq_client_lib
from google.adk.tools.bigquery import metadata_tool
from google.adk.tools.bigquery.config import BigQueryToolConfig
import google.auth
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "list_datasets", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_list_dataset_ids_no_default_auth(
mock_default_auth, mock_list_datasets
):
@@ -53,8 +55,8 @@ def test_list_dataset_ids_no_default_auth(
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "get_dataset", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset):
"""Test get_dataset_info tool invocation involves no default auth."""
mock_credentials = mock.create_autospec(Credentials, instance=True)
@@ -80,8 +82,8 @@ def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset):
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "list_tables", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables):
"""Test list_table_ids tool invocation involves no default auth."""
project = "my_project_id"
@@ -108,8 +110,8 @@ def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables):
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "get_table", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table):
"""Test get_table_info tool invocation involves no default auth."""
mock_credentials = mock.create_autospec(Credentials, instance=True)
@@ -137,8 +139,8 @@ def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table):
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.get_job", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "get_job", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job):
"""Test get_job_info tool invocation involves no default auth."""
mock_credentials = mock.create_autospec(Credentials, instance=True)
@@ -166,9 +168,7 @@ def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job):
mock_default_auth.assert_not_called()
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during list_dataset_ids tool invocation."""
bq_project = "my_project_id"
@@ -189,9 +189,7 @@ def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client):
]
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during get_dataset_info tool invocation."""
bq_project = "my_project_id"
@@ -215,9 +213,7 @@ def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client):
]
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_list_table_ids_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during list_table_ids tool invocation."""
bq_project = "my_project_id"
@@ -241,9 +237,7 @@ def test_list_table_ids_bq_client_creation(mock_get_bigquery_client):
]
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_get_table_info_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during get_table_info tool invocation."""
bq_project = "my_project_id"
@@ -268,9 +262,7 @@ def test_get_table_info_bq_client_creation(mock_get_bigquery_client):
]
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_get_job_info_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during get_table_info tool invocation."""
bq_project = "my_project_id"
@@ -20,19 +20,19 @@ import os
import textwrap
from typing import Optional
from unittest import mock
import uuid
import dateutil
import dateutil.relativedelta
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery import client as bq_client_lib
from google.adk.tools.bigquery import query_tool
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
from google.adk.tools.bigquery.query_tool import analyze_contribution
from google.adk.tools.bigquery.query_tool import detect_anomalies
from google.adk.tools.bigquery.query_tool import execute_sql
from google.adk.tools.bigquery.query_tool import forecast
from google.adk.tools.tool_context import ToolContext
import google.auth
from google.auth.exceptions import DefaultCredentialsError
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
@@ -654,7 +654,7 @@ def test_execute_sql_select_stmt(write_mode):
"_anonymous_dataset",
)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
# The mock instance
bq_client = Client.return_value
@@ -667,7 +667,7 @@ def test_execute_sql_select_stmt(write_mode):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -708,7 +708,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
# The mock instance
bq_client = Client.return_value
@@ -721,7 +721,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -762,7 +762,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
# The mock instance
bq_client = Client.return_value
@@ -775,7 +775,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {
@@ -823,7 +823,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
"_anonymous_dataset",
)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
# The mock instance
bq_client = Client.return_value
@@ -837,7 +837,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
@@ -889,7 +889,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
"_anonymous_dataset",
)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
# The mock instance
bq_client = Client.return_value
@@ -903,7 +903,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {
@@ -927,14 +927,14 @@ def test_execute_sql_dry_run_true():
"jobReference": {"projectId": project, "location": "US"},
}
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.to_api_repr.return_value = api_repr
bq_client.query.return_value = query_job
result = execute_sql(
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context, dry_run=True
)
assert result == {"status": "SUCCESS", "dry_run_info": api_repr}
@@ -953,9 +953,9 @@ def test_execute_sql_dry_run_true():
],
)
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True)
@mock.patch("google.cloud.bigquery.Client.query", autospec=True)
@mock.patch("google.auth.default", autospec=True)
@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True)
@mock.patch.object(bigquery.Client, "query", autospec=True)
@mock.patch.object(google.auth, "default", autospec=True)
def test_execute_sql_no_default_auth(
mock_default_auth, mock_query, mock_query_and_wait, write_mode
):
@@ -987,7 +987,9 @@ def test_execute_sql_no_default_auth(
mock_query_and_wait.return_value = query_result
# Test the tool worked without invoking default auth
result = execute_sql(project, query, credentials, tool_settings, tool_context)
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": query_result}
mock_default_auth.assert_not_called()
@@ -1103,8 +1105,8 @@ def test_execute_sql_no_default_auth(
],
)
@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True)
@mock.patch("google.cloud.bigquery.Client.query", autospec=True)
@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True)
@mock.patch.object(bigquery.Client, "query", autospec=True)
def test_execute_sql_result_dtype(
mock_query, mock_query_and_wait, query, query_result, tool_result_rows
):
@@ -1128,13 +1130,13 @@ def test_execute_sql_result_dtype(
mock_query_and_wait.return_value = query_result
# Test the tool worked without invoking default auth
result = execute_sql(project, query, credentials, tool_settings, tool_context)
result = query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
assert result == {"status": "SUCCESS", "rows": tool_result_rows}
@mock.patch(
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
)
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
def test_execute_sql_bq_client_creation(mock_get_bigquery_client):
"""Test BigQuery client creation params during execute_sql tool invocation."""
project = "my_project_id"
@@ -1143,8 +1145,9 @@ def test_execute_sql_bq_client_creation(mock_get_bigquery_client):
application_name = "my-agent"
tool_settings = BigQueryToolConfig(application_name=application_name)
tool_context = mock.create_autospec(ToolContext, instance=True)
execute_sql(project, query, credentials, tool_settings, tool_context)
query_tool.execute_sql(
project, query, credentials, tool_settings, tool_context
)
mock_get_bigquery_client.assert_called_once()
assert len(mock_get_bigquery_client.call_args.kwargs) == 4
assert mock_get_bigquery_client.call_args.kwargs["project"] == project
@@ -1164,7 +1167,7 @@ def test_execute_sql_unexpected_project_id():
tool_settings = BigQueryToolConfig(compute_project_id=compute_project_id)
tool_context = mock.create_autospec(ToolContext, instance=True)
result = execute_sql(
result = query_tool.execute_sql(
tool_call_project_id, query, credentials, tool_settings, tool_context
)
assert result == {
@@ -1180,13 +1183,13 @@ def test_execute_sql_unexpected_project_id():
# AI.Forecast calls _execute_sql with a specific query statement. We need to
# test that the query is properly constructed and call _execute_sql with the
# correct parameters exactly once.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
def test_forecast_with_table_id(mock_execute_sql):
mock_credentials = mock.MagicMock(spec=Credentials)
mock_settings = BigQueryToolConfig()
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
forecast(
query_tool.forecast(
project_id="test-project",
history_data="test-dataset.test-table",
timestamp_col="ts_col",
@@ -1222,14 +1225,14 @@ def test_forecast_with_table_id(mock_execute_sql):
# AI.Forecast calls _execute_sql with a specific query statement. We need to
# test that the query is properly constructed and call _execute_sql with the
# correct parameters exactly once.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
def test_forecast_with_query_statement(mock_execute_sql):
mock_credentials = mock.MagicMock(spec=Credentials)
mock_settings = BigQueryToolConfig()
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
history_data_query = "SELECT * FROM `test-dataset.test-table`"
forecast(
query_tool.forecast(
project_id="test-project",
history_data=history_data_query,
timestamp_col="ts_col",
@@ -1264,7 +1267,7 @@ def test_forecast_with_invalid_id_cols():
mock_settings = BigQueryToolConfig()
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
result = forecast(
result = query_tool.forecast(
project_id="test-project",
history_data="test-dataset.test-table",
timestamp_col="ts_col",
@@ -1282,8 +1285,8 @@ def test_forecast_with_invalid_id_cols():
# analyze_contribution calls _execute_sql twice. We need to test that the
# queries are properly constructed and call _execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
"""Test analyze_contribution tool invocation with a table id."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1291,8 +1294,7 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
analyze_contribution(
query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", "dim2"],
@@ -1335,8 +1337,8 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
# analyze_contribution calls _execute_sql twice. We need to test that the
# queries are properly constructed and call _execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql):
"""Test analyze_contribution tool invocation with a query statement."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1344,9 +1346,8 @@ def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
input_data_query = "SELECT * FROM `test-dataset.test-table`"
analyze_contribution(
query_tool.analyze_contribution(
project_id="test-project",
input_data=input_data_query,
dimension_id_cols=["dim1", "dim2"],
@@ -1392,7 +1393,7 @@ def test_analyze_contribution_with_invalid_dimension_id_cols():
mock_settings = BigQueryToolConfig()
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
result = analyze_contribution(
result = query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", 123],
@@ -1413,8 +1414,8 @@ def test_analyze_contribution_with_invalid_dimension_id_cols():
# detect_anomalies calls _execute_sql twice. We need to test that
# the queries are properly constructed and call _execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
"""Test time series anomaly detection tool invocation with a table id."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1422,9 +1423,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
history_data_query = "SELECT * FROM `test-dataset.test-table`"
detect_anomalies(
query_tool.detect_anomalies(
project_id="test-project",
history_data=history_data_query,
times_series_timestamp_col="ts_timestamp",
@@ -1466,8 +1466,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
# detect_anomalies calls _execute_sql twice. We need to test that
# the queries are properly constructed and call _execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
"""Test time series anomaly detection tool invocation with a table id."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1475,9 +1475,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
history_data_query = "SELECT * FROM `test-dataset.test-table`"
detect_anomalies(
query_tool.detect_anomalies(
project_id="test-project",
history_data=history_data_query,
times_series_timestamp_col="ts_timestamp",
@@ -1522,8 +1521,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
# detect_anomalies calls _execute_sql twice. We need to test that
# the queries are properly constructed and call _execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
"""Test time series anomaly detection tool with target data is provided."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1531,10 +1530,9 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
history_data_query = "SELECT * FROM `test-dataset.history-table`"
target_data_query = "SELECT * FROM `test-dataset.target-table`"
detect_anomalies(
query_tool.detect_anomalies(
project_id="test-project",
history_data=history_data_query,
times_series_timestamp_col="ts_timestamp",
@@ -1580,8 +1578,8 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
# detect_anomalies calls execute_sql twice. We need to test that
# the queries are properly constructed and call execute_sql with the correct
# parameters exactly twice.
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
@mock.patch("uuid.uuid4", autospec=True)
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
@mock.patch.object(uuid, "uuid4", autospec=True)
def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql):
"""Test time series anomaly detection tool invocation with a table id."""
mock_credentials = mock.MagicMock(spec=Credentials)
@@ -1589,9 +1587,8 @@ def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql):
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
mock_uuid.return_value = "test_uuid"
mock_execute_sql.return_value = {"status": "SUCCESS"}
history_data_query = "SELECT * FROM `test-dataset.test-table`"
detect_anomalies(
query_tool.detect_anomalies(
project_id="test-project",
history_data=history_data_query,
times_series_timestamp_col="ts_timestamp",
@@ -1637,7 +1634,7 @@ def test_detect_anomalies_with_invalid_id_cols():
mock_settings = BigQueryToolConfig()
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
result = detect_anomalies(
result = query_tool.detect_anomalies(
project_id="test-project",
history_data="test-dataset.test-table",
times_series_timestamp_col="ts_timestamp",
@@ -1680,14 +1677,14 @@ def test_execute_sql_job_labels(
tool_context = mock.create_autospec(ToolContext, instance=True)
tool_context.state.get.return_value = None
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
execute_sql(
query_tool.execute_sql(
project,
query,
credentials,
@@ -1713,7 +1710,7 @@ def test_execute_sql_job_labels(
("tool_call", "expected_label"),
[
pytest.param(
lambda tool_context: forecast(
lambda tool_context: query_tool.forecast(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
timestamp_col="ts_col",
@@ -1726,7 +1723,7 @@ def test_execute_sql_job_labels(
id="forecast",
),
pytest.param(
lambda tool_context: analyze_contribution(
lambda tool_context: query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", "dim2"],
@@ -1740,7 +1737,7 @@ def test_execute_sql_job_labels(
id="analyze-contribution",
),
pytest.param(
lambda tool_context: detect_anomalies(
lambda tool_context: query_tool.detect_anomalies(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
times_series_timestamp_col="ts_timestamp",
@@ -1757,7 +1754,7 @@ def test_execute_sql_job_labels(
def test_ml_tool_job_labels(tool_call, expected_label):
"""Test ML tools for job label."""
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
tool_context = mock.create_autospec(ToolContext, instance=True)
@@ -1785,14 +1782,16 @@ def test_execute_sql_max_rows_config():
tool_config = BigQueryToolConfig(max_query_result_rows=10)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
bq_client.query_and_wait.return_value = query_result[:10]
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = query_tool.execute_sql(
project, query, credentials, tool_config, tool_context
)
# Check that max_results was called with config value
bq_client.query_and_wait.assert_called_once()
@@ -1814,14 +1813,16 @@ def test_execute_sql_no_truncation():
tool_config = BigQueryToolConfig(max_query_result_rows=10)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
bq_client.query_and_wait.return_value = query_result
result = execute_sql(project, query, credentials, tool_config, tool_context)
result = query_tool.execute_sql(
project, query, credentials, tool_config, tool_context
)
# Check no truncation flag when fewer rows than limit
assert result["status"] == "SUCCESS"
@@ -1837,13 +1838,15 @@ def test_execute_sql_maximum_bytes_billed_config():
tool_config = BigQueryToolConfig(maximum_bytes_billed=11_000_000)
tool_context = mock.create_autospec(ToolContext, instance=True)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
bq_client = Client.return_value
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
execute_sql(project, query, credentials, tool_config, tool_context)
query_tool.execute_sql(
project, query, credentials, tool_config, tool_context
)
# Check that maximum_bytes_billed was called with config value
bq_client.query_and_wait.assert_called_once()
@@ -1855,7 +1858,7 @@ def test_execute_sql_maximum_bytes_billed_config():
("tool_call",),
[
pytest.param(
lambda settings, tool_context: execute_sql(
lambda settings, tool_context: query_tool.execute_sql(
project_id="test-project",
query="SELECT * FROM `test-dataset.test-table`",
credentials=mock.create_autospec(Credentials, instance=True),
@@ -1865,7 +1868,7 @@ def test_execute_sql_maximum_bytes_billed_config():
id="execute-sql",
),
pytest.param(
lambda settings, tool_context: forecast(
lambda settings, tool_context: query_tool.forecast(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
timestamp_col="ts_col",
@@ -1877,7 +1880,7 @@ def test_execute_sql_maximum_bytes_billed_config():
id="forecast",
),
pytest.param(
lambda settings, tool_context: analyze_contribution(
lambda settings, tool_context: query_tool.analyze_contribution(
project_id="test-project",
input_data="test-dataset.test-table",
dimension_id_cols=["dim1", "dim2"],
@@ -1890,7 +1893,7 @@ def test_execute_sql_maximum_bytes_billed_config():
id="analyze-contribution",
),
pytest.param(
lambda settings, tool_context: detect_anomalies(
lambda settings, tool_context: query_tool.detect_anomalies(
project_id="test-project",
history_data="SELECT * FROM `test-dataset.test-table`",
times_series_timestamp_col="ts_timestamp",