You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: make readability improvements in ADK BQ tool tests
PiperOrigin-RevId: 832110063
This commit is contained in:
committed by
Copybara-Service
parent
5adbf95a0a
commit
b7d571bc3f
@@ -19,7 +19,9 @@ from unittest import mock
|
||||
|
||||
import google.adk
|
||||
from google.adk.tools.bigquery.client import get_bigquery_client
|
||||
import google.auth
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
from google.cloud.bigquery import client as bigquery_client
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
@@ -41,7 +43,9 @@ def test_bigquery_client_project_set_explicit():
|
||||
# Let's simulate that no environment variables are set, so that any project
|
||||
# set in there does not interfere with this test
|
||||
with mock.patch.dict(os.environ, {}, clear=True):
|
||||
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
|
||||
with mock.patch.object(
|
||||
google.auth, "default", autospec=True
|
||||
) as mock_default_auth:
|
||||
# Simulate exception from default auth
|
||||
mock_default_auth.side_effect = DefaultCredentialsError(
|
||||
"Your default credentials were not found"
|
||||
@@ -66,7 +70,9 @@ def test_bigquery_client_project_set_with_default_auth():
|
||||
# Let's simulate that no environment variables are set, so that any project
|
||||
# set in there does not interfere with this test
|
||||
with mock.patch.dict(os.environ, {}, clear=True):
|
||||
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
|
||||
with mock.patch.object(
|
||||
google.auth, "default", autospec=True
|
||||
) as mock_default_auth:
|
||||
# Simulate credentials
|
||||
mock_creds = mock.create_autospec(Credentials, instance=True)
|
||||
|
||||
@@ -90,7 +96,9 @@ def test_bigquery_client_project_set_with_env():
|
||||
with mock.patch.dict(
|
||||
os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True
|
||||
):
|
||||
with mock.patch("google.auth.default", autospec=True) as mock_default_auth:
|
||||
with mock.patch.object(
|
||||
google.auth, "default", autospec=True
|
||||
) as mock_default_auth:
|
||||
# Simulate exception from default auth
|
||||
mock_default_auth.side_effect = DefaultCredentialsError(
|
||||
"Your default credentials were not found"
|
||||
@@ -112,8 +120,8 @@ def test_bigquery_client_project_set_with_env():
|
||||
|
||||
def test_bigquery_client_user_agent_default():
|
||||
"""Test BigQuery client default user agent."""
|
||||
with mock.patch(
|
||||
"google.cloud.bigquery.client.Connection", autospec=True
|
||||
with mock.patch.object(
|
||||
bigquery_client, "Connection", autospec=True
|
||||
) as mock_connection:
|
||||
# Trigger the BigQuery client creation
|
||||
get_bigquery_client(
|
||||
@@ -134,8 +142,8 @@ def test_bigquery_client_user_agent_default():
|
||||
|
||||
def test_bigquery_client_user_agent_custom():
|
||||
"""Test BigQuery client custom user agent."""
|
||||
with mock.patch(
|
||||
"google.cloud.bigquery.client.Connection", autospec=True
|
||||
with mock.patch.object(
|
||||
bigquery_client, "Connection", autospec=True
|
||||
) as mock_connection:
|
||||
# Trigger the BigQuery client creation
|
||||
get_bigquery_client(
|
||||
@@ -158,8 +166,8 @@ def test_bigquery_client_user_agent_custom():
|
||||
|
||||
def test_bigquery_client_user_agent_custom_list():
|
||||
"""Test BigQuery client custom user agent."""
|
||||
with mock.patch(
|
||||
"google.cloud.bigquery.client.Connection", autospec=True
|
||||
with mock.patch.object(
|
||||
bigquery_client, "Connection", autospec=True
|
||||
) as mock_connection:
|
||||
# Trigger the BigQuery client creation
|
||||
get_bigquery_client(
|
||||
|
||||
@@ -36,7 +36,6 @@ class TestBigQueryCredentials:
|
||||
to pass them directly without needing to provide client ID/secret.
|
||||
"""
|
||||
# Create a mock auth credentials object
|
||||
# auth_creds = google.auth.credentials.Credentials()
|
||||
auth_creds = mock.create_autospec(
|
||||
google.auth.credentials.Credentials, instance=True
|
||||
)
|
||||
|
||||
@@ -26,9 +26,7 @@ import yaml
|
||||
pytest.param("test_data/ask_data_insights_penguins_highest_mass.yaml"),
|
||||
],
|
||||
)
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.data_insights_tool.requests.Session.post"
|
||||
)
|
||||
@mock.patch.object(data_insights_tool.requests.Session, "post")
|
||||
def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path):
|
||||
"""Runs a full integration test for the ask_data_insights pipeline using data from a specific file."""
|
||||
# 1. Construct the full, absolute path to the data file
|
||||
@@ -65,7 +63,7 @@ def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path):
|
||||
assert result == expected_final_list
|
||||
|
||||
|
||||
@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream")
|
||||
@mock.patch.object(data_insights_tool, "_get_stream")
|
||||
def test_ask_data_insights_success(mock_get_stream):
|
||||
"""Tests the success path of ask_data_insights using decorators."""
|
||||
# 1. Configure the behavior of the mocked functions
|
||||
@@ -92,7 +90,7 @@ def test_ask_data_insights_success(mock_get_stream):
|
||||
mock_get_stream.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream")
|
||||
@mock.patch.object(data_insights_tool, "_get_stream")
|
||||
def test_ask_data_insights_handles_exception(mock_get_stream):
|
||||
"""Tests the exception path of ask_data_insights using decorators."""
|
||||
# 1. Configure one of the mocks to raise an error
|
||||
|
||||
@@ -17,16 +17,18 @@ from __future__ import annotations
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.tools.bigquery import client as bq_client_lib
|
||||
from google.adk.tools.bigquery import metadata_tool
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
import google.auth
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "list_datasets", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_list_dataset_ids_no_default_auth(
|
||||
mock_default_auth, mock_list_datasets
|
||||
):
|
||||
@@ -53,8 +55,8 @@ def test_list_dataset_ids_no_default_auth(
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "get_dataset", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset):
|
||||
"""Test get_dataset_info tool invocation involves no default auth."""
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
@@ -80,8 +82,8 @@ def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset):
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "list_tables", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables):
|
||||
"""Test list_table_ids tool invocation involves no default auth."""
|
||||
project = "my_project_id"
|
||||
@@ -108,8 +110,8 @@ def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables):
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "get_table", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table):
|
||||
"""Test get_table_info tool invocation involves no default auth."""
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
@@ -137,8 +139,8 @@ def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table):
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.get_job", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "get_job", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job):
|
||||
"""Test get_job_info tool invocation involves no default auth."""
|
||||
mock_credentials = mock.create_autospec(Credentials, instance=True)
|
||||
@@ -166,9 +168,7 @@ def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job):
|
||||
mock_default_auth.assert_not_called()
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during list_dataset_ids tool invocation."""
|
||||
bq_project = "my_project_id"
|
||||
@@ -189,9 +189,7 @@ def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client):
|
||||
]
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during get_dataset_info tool invocation."""
|
||||
bq_project = "my_project_id"
|
||||
@@ -215,9 +213,7 @@ def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client):
|
||||
]
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_list_table_ids_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during list_table_ids tool invocation."""
|
||||
bq_project = "my_project_id"
|
||||
@@ -241,9 +237,7 @@ def test_list_table_ids_bq_client_creation(mock_get_bigquery_client):
|
||||
]
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_get_table_info_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during get_table_info tool invocation."""
|
||||
bq_project = "my_project_id"
|
||||
@@ -268,9 +262,7 @@ def test_get_table_info_bq_client_creation(mock_get_bigquery_client):
|
||||
]
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_get_job_info_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during get_table_info tool invocation."""
|
||||
bq_project = "my_project_id"
|
||||
|
||||
@@ -20,19 +20,19 @@ import os
|
||||
import textwrap
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
import uuid
|
||||
|
||||
import dateutil
|
||||
import dateutil.relativedelta
|
||||
from google.adk.tools.base_tool import BaseTool
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
from google.adk.tools.bigquery import client as bq_client_lib
|
||||
from google.adk.tools.bigquery import query_tool
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
from google.adk.tools.bigquery.query_tool import analyze_contribution
|
||||
from google.adk.tools.bigquery.query_tool import detect_anomalies
|
||||
from google.adk.tools.bigquery.query_tool import execute_sql
|
||||
from google.adk.tools.bigquery.query_tool import forecast
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
import google.auth
|
||||
from google.auth.exceptions import DefaultCredentialsError
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
@@ -654,7 +654,7 @@ def test_execute_sql_select_stmt(write_mode):
|
||||
"_anonymous_dataset",
|
||||
)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
# The mock instance
|
||||
bq_client = Client.return_value
|
||||
|
||||
@@ -667,7 +667,7 @@ def test_execute_sql_select_stmt(write_mode):
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": query_result}
|
||||
@@ -708,7 +708,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
|
||||
tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
# The mock instance
|
||||
bq_client = Client.return_value
|
||||
|
||||
@@ -721,7 +721,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": query_result}
|
||||
@@ -762,7 +762,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
|
||||
tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
# The mock instance
|
||||
bq_client = Client.return_value
|
||||
|
||||
@@ -775,7 +775,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {
|
||||
@@ -823,7 +823,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
|
||||
"_anonymous_dataset",
|
||||
)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
# The mock instance
|
||||
bq_client = Client.return_value
|
||||
|
||||
@@ -837,7 +837,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": query_result}
|
||||
@@ -889,7 +889,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
|
||||
"_anonymous_dataset",
|
||||
)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
# The mock instance
|
||||
bq_client = Client.return_value
|
||||
|
||||
@@ -903,7 +903,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target(
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {
|
||||
@@ -927,14 +927,14 @@ def test_execute_sql_dry_run_true():
|
||||
"jobReference": {"projectId": project, "location": "US"},
|
||||
}
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.to_api_repr.return_value = api_repr
|
||||
bq_client.query.return_value = query_job
|
||||
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context, dry_run=True
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "dry_run_info": api_repr}
|
||||
@@ -953,9 +953,9 @@ def test_execute_sql_dry_run_true():
|
||||
],
|
||||
)
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.query", autospec=True)
|
||||
@mock.patch("google.auth.default", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "query", autospec=True)
|
||||
@mock.patch.object(google.auth, "default", autospec=True)
|
||||
def test_execute_sql_no_default_auth(
|
||||
mock_default_auth, mock_query, mock_query_and_wait, write_mode
|
||||
):
|
||||
@@ -987,7 +987,9 @@ def test_execute_sql_no_default_auth(
|
||||
mock_query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool worked without invoking default auth
|
||||
result = execute_sql(project, query, credentials, tool_settings, tool_context)
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": query_result}
|
||||
mock_default_auth.assert_not_called()
|
||||
|
||||
@@ -1103,8 +1105,8 @@ def test_execute_sql_no_default_auth(
|
||||
],
|
||||
)
|
||||
@mock.patch.dict(os.environ, {}, clear=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True)
|
||||
@mock.patch("google.cloud.bigquery.Client.query", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True)
|
||||
@mock.patch.object(bigquery.Client, "query", autospec=True)
|
||||
def test_execute_sql_result_dtype(
|
||||
mock_query, mock_query_and_wait, query, query_result, tool_result_rows
|
||||
):
|
||||
@@ -1128,13 +1130,13 @@ def test_execute_sql_result_dtype(
|
||||
mock_query_and_wait.return_value = query_result
|
||||
|
||||
# Test the tool worked without invoking default auth
|
||||
result = execute_sql(project, query, credentials, tool_settings, tool_context)
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {"status": "SUCCESS", "rows": tool_result_rows}
|
||||
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.bigquery.client.get_bigquery_client", autospec=True
|
||||
)
|
||||
@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True)
|
||||
def test_execute_sql_bq_client_creation(mock_get_bigquery_client):
|
||||
"""Test BigQuery client creation params during execute_sql tool invocation."""
|
||||
project = "my_project_id"
|
||||
@@ -1143,8 +1145,9 @@ def test_execute_sql_bq_client_creation(mock_get_bigquery_client):
|
||||
application_name = "my-agent"
|
||||
tool_settings = BigQueryToolConfig(application_name=application_name)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
execute_sql(project, query, credentials, tool_settings, tool_context)
|
||||
query_tool.execute_sql(
|
||||
project, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
mock_get_bigquery_client.assert_called_once()
|
||||
assert len(mock_get_bigquery_client.call_args.kwargs) == 4
|
||||
assert mock_get_bigquery_client.call_args.kwargs["project"] == project
|
||||
@@ -1164,7 +1167,7 @@ def test_execute_sql_unexpected_project_id():
|
||||
tool_settings = BigQueryToolConfig(compute_project_id=compute_project_id)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
result = execute_sql(
|
||||
result = query_tool.execute_sql(
|
||||
tool_call_project_id, query, credentials, tool_settings, tool_context
|
||||
)
|
||||
assert result == {
|
||||
@@ -1180,13 +1183,13 @@ def test_execute_sql_unexpected_project_id():
|
||||
# AI.Forecast calls _execute_sql with a specific query statement. We need to
|
||||
# test that the query is properly constructed and call _execute_sql with the
|
||||
# correct parameters exactly once.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
def test_forecast_with_table_id(mock_execute_sql):
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
mock_settings = BigQueryToolConfig()
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
forecast(
|
||||
query_tool.forecast(
|
||||
project_id="test-project",
|
||||
history_data="test-dataset.test-table",
|
||||
timestamp_col="ts_col",
|
||||
@@ -1222,14 +1225,14 @@ def test_forecast_with_table_id(mock_execute_sql):
|
||||
# AI.Forecast calls _execute_sql with a specific query statement. We need to
|
||||
# test that the query is properly constructed and call _execute_sql with the
|
||||
# correct parameters exactly once.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
def test_forecast_with_query_statement(mock_execute_sql):
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
mock_settings = BigQueryToolConfig()
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
history_data_query = "SELECT * FROM `test-dataset.test-table`"
|
||||
forecast(
|
||||
query_tool.forecast(
|
||||
project_id="test-project",
|
||||
history_data=history_data_query,
|
||||
timestamp_col="ts_col",
|
||||
@@ -1264,7 +1267,7 @@ def test_forecast_with_invalid_id_cols():
|
||||
mock_settings = BigQueryToolConfig()
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
result = forecast(
|
||||
result = query_tool.forecast(
|
||||
project_id="test-project",
|
||||
history_data="test-dataset.test-table",
|
||||
timestamp_col="ts_col",
|
||||
@@ -1282,8 +1285,8 @@ def test_forecast_with_invalid_id_cols():
|
||||
# analyze_contribution calls _execute_sql twice. We need to test that the
|
||||
# queries are properly constructed and call _execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
|
||||
"""Test analyze_contribution tool invocation with a table id."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1291,8 +1294,7 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
analyze_contribution(
|
||||
query_tool.analyze_contribution(
|
||||
project_id="test-project",
|
||||
input_data="test-dataset.test-table",
|
||||
dimension_id_cols=["dim1", "dim2"],
|
||||
@@ -1335,8 +1337,8 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql):
|
||||
# analyze_contribution calls _execute_sql twice. We need to test that the
|
||||
# queries are properly constructed and call _execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql):
|
||||
"""Test analyze_contribution tool invocation with a query statement."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1344,9 +1346,8 @@ def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
input_data_query = "SELECT * FROM `test-dataset.test-table`"
|
||||
analyze_contribution(
|
||||
query_tool.analyze_contribution(
|
||||
project_id="test-project",
|
||||
input_data=input_data_query,
|
||||
dimension_id_cols=["dim1", "dim2"],
|
||||
@@ -1392,7 +1393,7 @@ def test_analyze_contribution_with_invalid_dimension_id_cols():
|
||||
mock_settings = BigQueryToolConfig()
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
result = analyze_contribution(
|
||||
result = query_tool.analyze_contribution(
|
||||
project_id="test-project",
|
||||
input_data="test-dataset.test-table",
|
||||
dimension_id_cols=["dim1", 123],
|
||||
@@ -1413,8 +1414,8 @@ def test_analyze_contribution_with_invalid_dimension_id_cols():
|
||||
# detect_anomalies calls _execute_sql twice. We need to test that
|
||||
# the queries are properly constructed and call _execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
|
||||
"""Test time series anomaly detection tool invocation with a table id."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1422,9 +1423,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
history_data_query = "SELECT * FROM `test-dataset.test-table`"
|
||||
detect_anomalies(
|
||||
query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data=history_data_query,
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1466,8 +1466,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql):
|
||||
# detect_anomalies calls _execute_sql twice. We need to test that
|
||||
# the queries are properly constructed and call _execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
|
||||
"""Test time series anomaly detection tool invocation with a table id."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1475,9 +1475,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
history_data_query = "SELECT * FROM `test-dataset.test-table`"
|
||||
detect_anomalies(
|
||||
query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data=history_data_query,
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1522,8 +1521,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql):
|
||||
# detect_anomalies calls _execute_sql twice. We need to test that
|
||||
# the queries are properly constructed and call _execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
|
||||
"""Test time series anomaly detection tool with target data is provided."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1531,10 +1530,9 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
history_data_query = "SELECT * FROM `test-dataset.history-table`"
|
||||
target_data_query = "SELECT * FROM `test-dataset.target-table`"
|
||||
detect_anomalies(
|
||||
query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data=history_data_query,
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1580,8 +1578,8 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql):
|
||||
# detect_anomalies calls execute_sql twice. We need to test that
|
||||
# the queries are properly constructed and call execute_sql with the correct
|
||||
# parameters exactly twice.
|
||||
@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True)
|
||||
@mock.patch("uuid.uuid4", autospec=True)
|
||||
@mock.patch.object(query_tool, "_execute_sql", autospec=True)
|
||||
@mock.patch.object(uuid, "uuid4", autospec=True)
|
||||
def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql):
|
||||
"""Test time series anomaly detection tool invocation with a table id."""
|
||||
mock_credentials = mock.MagicMock(spec=Credentials)
|
||||
@@ -1589,9 +1587,8 @@ def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql):
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
mock_uuid.return_value = "test_uuid"
|
||||
mock_execute_sql.return_value = {"status": "SUCCESS"}
|
||||
|
||||
history_data_query = "SELECT * FROM `test-dataset.test-table`"
|
||||
detect_anomalies(
|
||||
query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data=history_data_query,
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1637,7 +1634,7 @@ def test_detect_anomalies_with_invalid_id_cols():
|
||||
mock_settings = BigQueryToolConfig()
|
||||
mock_tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
result = detect_anomalies(
|
||||
result = query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data="test-dataset.test-table",
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1680,14 +1677,14 @@ def test_execute_sql_job_labels(
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
tool_context.state.get.return_value = None
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = statement_type
|
||||
bq_client.query.return_value = query_job
|
||||
|
||||
execute_sql(
|
||||
query_tool.execute_sql(
|
||||
project,
|
||||
query,
|
||||
credentials,
|
||||
@@ -1713,7 +1710,7 @@ def test_execute_sql_job_labels(
|
||||
("tool_call", "expected_label"),
|
||||
[
|
||||
pytest.param(
|
||||
lambda tool_context: forecast(
|
||||
lambda tool_context: query_tool.forecast(
|
||||
project_id="test-project",
|
||||
history_data="SELECT * FROM `test-dataset.test-table`",
|
||||
timestamp_col="ts_col",
|
||||
@@ -1726,7 +1723,7 @@ def test_execute_sql_job_labels(
|
||||
id="forecast",
|
||||
),
|
||||
pytest.param(
|
||||
lambda tool_context: analyze_contribution(
|
||||
lambda tool_context: query_tool.analyze_contribution(
|
||||
project_id="test-project",
|
||||
input_data="test-dataset.test-table",
|
||||
dimension_id_cols=["dim1", "dim2"],
|
||||
@@ -1740,7 +1737,7 @@ def test_execute_sql_job_labels(
|
||||
id="analyze-contribution",
|
||||
),
|
||||
pytest.param(
|
||||
lambda tool_context: detect_anomalies(
|
||||
lambda tool_context: query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data="SELECT * FROM `test-dataset.test-table`",
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
@@ -1757,7 +1754,7 @@ def test_execute_sql_job_labels(
|
||||
def test_ml_tool_job_labels(tool_call, expected_label):
|
||||
"""Test ML tools for job label."""
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
@@ -1785,14 +1782,16 @@ def test_execute_sql_max_rows_config():
|
||||
tool_config = BigQueryToolConfig(max_query_result_rows=10)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = statement_type
|
||||
bq_client.query.return_value = query_job
|
||||
bq_client.query_and_wait.return_value = query_result[:10]
|
||||
|
||||
result = execute_sql(project, query, credentials, tool_config, tool_context)
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_config, tool_context
|
||||
)
|
||||
|
||||
# Check that max_results was called with config value
|
||||
bq_client.query_and_wait.assert_called_once()
|
||||
@@ -1814,14 +1813,16 @@ def test_execute_sql_no_truncation():
|
||||
tool_config = BigQueryToolConfig(max_query_result_rows=10)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = statement_type
|
||||
bq_client.query.return_value = query_job
|
||||
bq_client.query_and_wait.return_value = query_result
|
||||
|
||||
result = execute_sql(project, query, credentials, tool_config, tool_context)
|
||||
result = query_tool.execute_sql(
|
||||
project, query, credentials, tool_config, tool_context
|
||||
)
|
||||
|
||||
# Check no truncation flag when fewer rows than limit
|
||||
assert result["status"] == "SUCCESS"
|
||||
@@ -1837,13 +1838,15 @@ def test_execute_sql_maximum_bytes_billed_config():
|
||||
tool_config = BigQueryToolConfig(maximum_bytes_billed=11_000_000)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
with mock.patch.object(bigquery, "Client", autospec=True) as Client:
|
||||
bq_client = Client.return_value
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = statement_type
|
||||
bq_client.query.return_value = query_job
|
||||
|
||||
execute_sql(project, query, credentials, tool_config, tool_context)
|
||||
query_tool.execute_sql(
|
||||
project, query, credentials, tool_config, tool_context
|
||||
)
|
||||
|
||||
# Check that maximum_bytes_billed was called with config value
|
||||
bq_client.query_and_wait.assert_called_once()
|
||||
@@ -1855,7 +1858,7 @@ def test_execute_sql_maximum_bytes_billed_config():
|
||||
("tool_call",),
|
||||
[
|
||||
pytest.param(
|
||||
lambda settings, tool_context: execute_sql(
|
||||
lambda settings, tool_context: query_tool.execute_sql(
|
||||
project_id="test-project",
|
||||
query="SELECT * FROM `test-dataset.test-table`",
|
||||
credentials=mock.create_autospec(Credentials, instance=True),
|
||||
@@ -1865,7 +1868,7 @@ def test_execute_sql_maximum_bytes_billed_config():
|
||||
id="execute-sql",
|
||||
),
|
||||
pytest.param(
|
||||
lambda settings, tool_context: forecast(
|
||||
lambda settings, tool_context: query_tool.forecast(
|
||||
project_id="test-project",
|
||||
history_data="SELECT * FROM `test-dataset.test-table`",
|
||||
timestamp_col="ts_col",
|
||||
@@ -1877,7 +1880,7 @@ def test_execute_sql_maximum_bytes_billed_config():
|
||||
id="forecast",
|
||||
),
|
||||
pytest.param(
|
||||
lambda settings, tool_context: analyze_contribution(
|
||||
lambda settings, tool_context: query_tool.analyze_contribution(
|
||||
project_id="test-project",
|
||||
input_data="test-dataset.test-table",
|
||||
dimension_id_cols=["dim1", "dim2"],
|
||||
@@ -1890,7 +1893,7 @@ def test_execute_sql_maximum_bytes_billed_config():
|
||||
id="analyze-contribution",
|
||||
),
|
||||
pytest.param(
|
||||
lambda settings, tool_context: detect_anomalies(
|
||||
lambda settings, tool_context: query_tool.detect_anomalies(
|
||||
project_id="test-project",
|
||||
history_data="SELECT * FROM `test-dataset.test-table`",
|
||||
times_series_timestamp_col="ts_timestamp",
|
||||
|
||||
Reference in New Issue
Block a user