diff --git a/tests/unittests/tools/bigquery/test_bigquery_client.py b/tests/unittests/tools/bigquery/test_bigquery_client.py index e36d4b06..b56873a0 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_client.py +++ b/tests/unittests/tools/bigquery/test_bigquery_client.py @@ -19,7 +19,9 @@ from unittest import mock import google.adk from google.adk.tools.bigquery.client import get_bigquery_client +import google.auth from google.auth.exceptions import DefaultCredentialsError +from google.cloud.bigquery import client as bigquery_client from google.oauth2.credentials import Credentials @@ -41,7 +43,9 @@ def test_bigquery_client_project_set_explicit(): # Let's simulate that no environment variables are set, so that any project # set in there does not interfere with this test with mock.patch.dict(os.environ, {}, clear=True): - with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + with mock.patch.object( + google.auth, "default", autospec=True + ) as mock_default_auth: # Simulate exception from default auth mock_default_auth.side_effect = DefaultCredentialsError( "Your default credentials were not found" @@ -66,7 +70,9 @@ def test_bigquery_client_project_set_with_default_auth(): # Let's simulate that no environment variables are set, so that any project # set in there does not interfere with this test with mock.patch.dict(os.environ, {}, clear=True): - with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + with mock.patch.object( + google.auth, "default", autospec=True + ) as mock_default_auth: # Simulate credentials mock_creds = mock.create_autospec(Credentials, instance=True) @@ -90,7 +96,9 @@ def test_bigquery_client_project_set_with_env(): with mock.patch.dict( os.environ, {"GOOGLE_CLOUD_PROJECT": "test-gcp-project"}, clear=True ): - with mock.patch("google.auth.default", autospec=True) as mock_default_auth: + with mock.patch.object( + google.auth, "default", autospec=True + ) as mock_default_auth: # Simulate exception from default auth mock_default_auth.side_effect = DefaultCredentialsError( "Your default credentials were not found" @@ -112,8 +120,8 @@ def test_bigquery_client_project_set_with_env(): def test_bigquery_client_user_agent_default(): """Test BigQuery client default user agent.""" - with mock.patch( - "google.cloud.bigquery.client.Connection", autospec=True + with mock.patch.object( + bigquery_client, "Connection", autospec=True ) as mock_connection: # Trigger the BigQuery client creation get_bigquery_client( @@ -134,8 +142,8 @@ def test_bigquery_client_user_agent_default(): def test_bigquery_client_user_agent_custom(): """Test BigQuery client custom user agent.""" - with mock.patch( - "google.cloud.bigquery.client.Connection", autospec=True + with mock.patch.object( + bigquery_client, "Connection", autospec=True ) as mock_connection: # Trigger the BigQuery client creation get_bigquery_client( @@ -158,8 +166,8 @@ def test_bigquery_client_user_agent_custom(): def test_bigquery_client_user_agent_custom_list(): """Test BigQuery client custom user agent.""" - with mock.patch( - "google.cloud.bigquery.client.Connection", autospec=True + with mock.patch.object( + bigquery_client, "Connection", autospec=True ) as mock_connection: # Trigger the BigQuery client creation get_bigquery_client( diff --git a/tests/unittests/tools/bigquery/test_bigquery_credentials.py b/tests/unittests/tools/bigquery/test_bigquery_credentials.py index 4201cf3b..2342446c 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_credentials.py +++ b/tests/unittests/tools/bigquery/test_bigquery_credentials.py @@ -36,7 +36,6 @@ class TestBigQueryCredentials: to pass them directly without needing to provide client ID/secret. """ # Create a mock auth credentials object - # auth_creds = google.auth.credentials.Credentials() auth_creds = mock.create_autospec( google.auth.credentials.Credentials, instance=True ) diff --git a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py index 2c52d1e6..f7d0fa06 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_data_insights_tool.py @@ -26,9 +26,7 @@ import yaml pytest.param("test_data/ask_data_insights_penguins_highest_mass.yaml"), ], ) -@mock.patch( - "google.adk.tools.bigquery.data_insights_tool.requests.Session.post" -) +@mock.patch.object(data_insights_tool.requests.Session, "post") def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path): """Runs a full integration test for the ask_data_insights pipeline using data from a specific file.""" # 1. Construct the full, absolute path to the data file @@ -65,7 +63,7 @@ def test_ask_data_insights_pipeline_from_file(mock_post, case_file_path): assert result == expected_final_list -@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream") +@mock.patch.object(data_insights_tool, "_get_stream") def test_ask_data_insights_success(mock_get_stream): """Tests the success path of ask_data_insights using decorators.""" # 1. Configure the behavior of the mocked functions @@ -92,7 +90,7 @@ def test_ask_data_insights_success(mock_get_stream): mock_get_stream.assert_called_once() -@mock.patch("google.adk.tools.bigquery.data_insights_tool._get_stream") +@mock.patch.object(data_insights_tool, "_get_stream") def test_ask_data_insights_handles_exception(mock_get_stream): """Tests the exception path of ask_data_insights using decorators.""" # 1. Configure one of the mocks to raise an error diff --git a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py index 8727ce20..197884ce 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_metadata_tool.py @@ -17,16 +17,18 @@ from __future__ import annotations import os from unittest import mock +from google.adk.tools.bigquery import client as bq_client_lib from google.adk.tools.bigquery import metadata_tool from google.adk.tools.bigquery.config import BigQueryToolConfig +import google.auth from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery from google.oauth2.credentials import Credentials @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.list_datasets", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "list_datasets", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_list_dataset_ids_no_default_auth( mock_default_auth, mock_list_datasets ): @@ -53,8 +55,8 @@ def test_list_dataset_ids_no_default_auth( @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.get_dataset", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "get_dataset", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset): """Test get_dataset_info tool invocation involves no default auth.""" mock_credentials = mock.create_autospec(Credentials, instance=True) @@ -80,8 +82,8 @@ def test_get_dataset_info_no_default_auth(mock_default_auth, mock_get_dataset): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.list_tables", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "list_tables", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables): """Test list_table_ids tool invocation involves no default auth.""" project = "my_project_id" @@ -108,8 +110,8 @@ def test_list_table_ids_no_default_auth(mock_default_auth, mock_list_tables): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.get_table", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "get_table", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table): """Test get_table_info tool invocation involves no default auth.""" mock_credentials = mock.create_autospec(Credentials, instance=True) @@ -137,8 +139,8 @@ def test_get_table_info_no_default_auth(mock_default_auth, mock_get_table): @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.get_job", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "get_job", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job): """Test get_job_info tool invocation involves no default auth.""" mock_credentials = mock.create_autospec(Credentials, instance=True) @@ -166,9 +168,7 @@ def test_get_job_info_no_default_auth(mock_default_auth, mock_get_job): mock_default_auth.assert_not_called() -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during list_dataset_ids tool invocation.""" bq_project = "my_project_id" @@ -189,9 +189,7 @@ def test_list_dataset_ids_bq_client_creation(mock_get_bigquery_client): ] -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during get_dataset_info tool invocation.""" bq_project = "my_project_id" @@ -215,9 +213,7 @@ def test_get_dataset_info_bq_client_creation(mock_get_bigquery_client): ] -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_list_table_ids_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during list_table_ids tool invocation.""" bq_project = "my_project_id" @@ -241,9 +237,7 @@ def test_list_table_ids_bq_client_creation(mock_get_bigquery_client): ] -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_get_table_info_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during get_table_info tool invocation.""" bq_project = "my_project_id" @@ -268,9 +262,7 @@ def test_get_table_info_bq_client_creation(mock_get_bigquery_client): ] -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_get_job_info_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during get_table_info tool invocation.""" bq_project = "my_project_id" diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index b88715a1..5482ad0b 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -20,19 +20,19 @@ import os import textwrap from typing import Optional from unittest import mock +import uuid import dateutil import dateutil.relativedelta from google.adk.tools.base_tool import BaseTool from google.adk.tools.bigquery import BigQueryCredentialsConfig from google.adk.tools.bigquery import BigQueryToolset +from google.adk.tools.bigquery import client as bq_client_lib +from google.adk.tools.bigquery import query_tool from google.adk.tools.bigquery.config import BigQueryToolConfig from google.adk.tools.bigquery.config import WriteMode -from google.adk.tools.bigquery.query_tool import analyze_contribution -from google.adk.tools.bigquery.query_tool import detect_anomalies -from google.adk.tools.bigquery.query_tool import execute_sql -from google.adk.tools.bigquery.query_tool import forecast from google.adk.tools.tool_context import ToolContext +import google.auth from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery from google.oauth2.credentials import Credentials @@ -654,7 +654,7 @@ def test_execute_sql_select_stmt(write_mode): "_anonymous_dataset", ) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: # The mock instance bq_client = Client.return_value @@ -667,7 +667,7 @@ def test_execute_sql_select_stmt(write_mode): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -708,7 +708,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): tool_settings = BigQueryToolConfig(write_mode=WriteMode.ALLOWED) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: # The mock instance bq_client = Client.return_value @@ -721,7 +721,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -762,7 +762,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): tool_settings = BigQueryToolConfig(write_mode=WriteMode.BLOCKED) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: # The mock instance bq_client = Client.return_value @@ -775,7 +775,7 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context ) assert result == { @@ -823,7 +823,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): "_anonymous_dataset", ) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: # The mock instance bq_client = Client.return_value @@ -837,7 +837,7 @@ def test_execute_sql_non_select_stmt_write_protected(query, statement_type): bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context ) assert result == {"status": "SUCCESS", "rows": query_result} @@ -889,7 +889,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( "_anonymous_dataset", ) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: # The mock instance bq_client = Client.return_value @@ -903,7 +903,7 @@ def test_execute_sql_non_select_stmt_write_protected_persistent_target( bq_client.query_and_wait.return_value = query_result # Test the tool - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context ) assert result == { @@ -927,14 +927,14 @@ def test_execute_sql_dry_run_true(): "jobReference": {"projectId": project, "location": "US"}, } - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value query_job = mock.create_autospec(bigquery.QueryJob) query_job.to_api_repr.return_value = api_repr bq_client.query.return_value = query_job - result = execute_sql( + result = query_tool.execute_sql( project, query, credentials, tool_settings, tool_context, dry_run=True ) assert result == {"status": "SUCCESS", "dry_run_info": api_repr} @@ -953,9 +953,9 @@ def test_execute_sql_dry_run_true(): ], ) @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) -@mock.patch("google.cloud.bigquery.Client.query", autospec=True) -@mock.patch("google.auth.default", autospec=True) +@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True) +@mock.patch.object(bigquery.Client, "query", autospec=True) +@mock.patch.object(google.auth, "default", autospec=True) def test_execute_sql_no_default_auth( mock_default_auth, mock_query, mock_query_and_wait, write_mode ): @@ -987,7 +987,9 @@ def test_execute_sql_no_default_auth( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_settings, tool_context) + result = query_tool.execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": query_result} mock_default_auth.assert_not_called() @@ -1103,8 +1105,8 @@ def test_execute_sql_no_default_auth( ], ) @mock.patch.dict(os.environ, {}, clear=True) -@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) -@mock.patch("google.cloud.bigquery.Client.query", autospec=True) +@mock.patch.object(bigquery.Client, "query_and_wait", autospec=True) +@mock.patch.object(bigquery.Client, "query", autospec=True) def test_execute_sql_result_dtype( mock_query, mock_query_and_wait, query, query_result, tool_result_rows ): @@ -1128,13 +1130,13 @@ def test_execute_sql_result_dtype( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_settings, tool_context) + result = query_tool.execute_sql( + project, query, credentials, tool_settings, tool_context + ) assert result == {"status": "SUCCESS", "rows": tool_result_rows} -@mock.patch( - "google.adk.tools.bigquery.client.get_bigquery_client", autospec=True -) +@mock.patch.object(bq_client_lib, "get_bigquery_client", autospec=True) def test_execute_sql_bq_client_creation(mock_get_bigquery_client): """Test BigQuery client creation params during execute_sql tool invocation.""" project = "my_project_id" @@ -1143,8 +1145,9 @@ def test_execute_sql_bq_client_creation(mock_get_bigquery_client): application_name = "my-agent" tool_settings = BigQueryToolConfig(application_name=application_name) tool_context = mock.create_autospec(ToolContext, instance=True) - - execute_sql(project, query, credentials, tool_settings, tool_context) + query_tool.execute_sql( + project, query, credentials, tool_settings, tool_context + ) mock_get_bigquery_client.assert_called_once() assert len(mock_get_bigquery_client.call_args.kwargs) == 4 assert mock_get_bigquery_client.call_args.kwargs["project"] == project @@ -1164,7 +1167,7 @@ def test_execute_sql_unexpected_project_id(): tool_settings = BigQueryToolConfig(compute_project_id=compute_project_id) tool_context = mock.create_autospec(ToolContext, instance=True) - result = execute_sql( + result = query_tool.execute_sql( tool_call_project_id, query, credentials, tool_settings, tool_context ) assert result == { @@ -1180,13 +1183,13 @@ def test_execute_sql_unexpected_project_id(): # AI.Forecast calls _execute_sql with a specific query statement. We need to # test that the query is properly constructed and call _execute_sql with the # correct parameters exactly once. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) def test_forecast_with_table_id(mock_execute_sql): mock_credentials = mock.MagicMock(spec=Credentials) mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) - forecast( + query_tool.forecast( project_id="test-project", history_data="test-dataset.test-table", timestamp_col="ts_col", @@ -1222,14 +1225,14 @@ def test_forecast_with_table_id(mock_execute_sql): # AI.Forecast calls _execute_sql with a specific query statement. We need to # test that the query is properly constructed and call _execute_sql with the # correct parameters exactly once. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) def test_forecast_with_query_statement(mock_execute_sql): mock_credentials = mock.MagicMock(spec=Credentials) mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) history_data_query = "SELECT * FROM `test-dataset.test-table`" - forecast( + query_tool.forecast( project_id="test-project", history_data=history_data_query, timestamp_col="ts_col", @@ -1264,7 +1267,7 @@ def test_forecast_with_invalid_id_cols(): mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) - result = forecast( + result = query_tool.forecast( project_id="test-project", history_data="test-dataset.test-table", timestamp_col="ts_col", @@ -1282,8 +1285,8 @@ def test_forecast_with_invalid_id_cols(): # analyze_contribution calls _execute_sql twice. We need to test that the # queries are properly constructed and call _execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): """Test analyze_contribution tool invocation with a table id.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1291,8 +1294,7 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - - analyze_contribution( + query_tool.analyze_contribution( project_id="test-project", input_data="test-dataset.test-table", dimension_id_cols=["dim1", "dim2"], @@ -1335,8 +1337,8 @@ def test_analyze_contribution_with_table_id(mock_uuid, mock_execute_sql): # analyze_contribution calls _execute_sql twice. We need to test that the # queries are properly constructed and call _execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql): """Test analyze_contribution tool invocation with a query statement.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1344,9 +1346,8 @@ def test_analyze_contribution_with_query_statement(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - input_data_query = "SELECT * FROM `test-dataset.test-table`" - analyze_contribution( + query_tool.analyze_contribution( project_id="test-project", input_data=input_data_query, dimension_id_cols=["dim1", "dim2"], @@ -1392,7 +1393,7 @@ def test_analyze_contribution_with_invalid_dimension_id_cols(): mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) - result = analyze_contribution( + result = query_tool.analyze_contribution( project_id="test-project", input_data="test-dataset.test-table", dimension_id_cols=["dim1", 123], @@ -1413,8 +1414,8 @@ def test_analyze_contribution_with_invalid_dimension_id_cols(): # detect_anomalies calls _execute_sql twice. We need to test that # the queries are properly constructed and call _execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql): """Test time series anomaly detection tool invocation with a table id.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1422,9 +1423,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" - detect_anomalies( + query_tool.detect_anomalies( project_id="test-project", history_data=history_data_query, times_series_timestamp_col="ts_timestamp", @@ -1466,8 +1466,8 @@ def test_detect_anomalies_with_table_id(mock_uuid, mock_execute_sql): # detect_anomalies calls _execute_sql twice. We need to test that # the queries are properly constructed and call _execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql): """Test time series anomaly detection tool invocation with a table id.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1475,9 +1475,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" - detect_anomalies( + query_tool.detect_anomalies( project_id="test-project", history_data=history_data_query, times_series_timestamp_col="ts_timestamp", @@ -1522,8 +1521,8 @@ def test_detect_anomalies_with_custom_params(mock_uuid, mock_execute_sql): # detect_anomalies calls _execute_sql twice. We need to test that # the queries are properly constructed and call _execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql): """Test time series anomaly detection tool with target data is provided.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1531,10 +1530,9 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.history-table`" target_data_query = "SELECT * FROM `test-dataset.target-table`" - detect_anomalies( + query_tool.detect_anomalies( project_id="test-project", history_data=history_data_query, times_series_timestamp_col="ts_timestamp", @@ -1580,8 +1578,8 @@ def test_detect_anomalies_on_target_table(mock_uuid, mock_execute_sql): # detect_anomalies calls execute_sql twice. We need to test that # the queries are properly constructed and call execute_sql with the correct # parameters exactly twice. -@mock.patch("google.adk.tools.bigquery.query_tool._execute_sql", autospec=True) -@mock.patch("uuid.uuid4", autospec=True) +@mock.patch.object(query_tool, "_execute_sql", autospec=True) +@mock.patch.object(uuid, "uuid4", autospec=True) def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql): """Test time series anomaly detection tool invocation with a table id.""" mock_credentials = mock.MagicMock(spec=Credentials) @@ -1589,9 +1587,8 @@ def test_detect_anomalies_with_str_table_id(mock_uuid, mock_execute_sql): mock_tool_context = mock.create_autospec(ToolContext, instance=True) mock_uuid.return_value = "test_uuid" mock_execute_sql.return_value = {"status": "SUCCESS"} - history_data_query = "SELECT * FROM `test-dataset.test-table`" - detect_anomalies( + query_tool.detect_anomalies( project_id="test-project", history_data=history_data_query, times_series_timestamp_col="ts_timestamp", @@ -1637,7 +1634,7 @@ def test_detect_anomalies_with_invalid_id_cols(): mock_settings = BigQueryToolConfig() mock_tool_context = mock.create_autospec(ToolContext, instance=True) - result = detect_anomalies( + result = query_tool.detect_anomalies( project_id="test-project", history_data="test-dataset.test-table", times_series_timestamp_col="ts_timestamp", @@ -1680,14 +1677,14 @@ def test_execute_sql_job_labels( tool_context = mock.create_autospec(ToolContext, instance=True) tool_context.state.get.return_value = None - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value query_job = mock.create_autospec(bigquery.QueryJob) query_job.statement_type = statement_type bq_client.query.return_value = query_job - execute_sql( + query_tool.execute_sql( project, query, credentials, @@ -1713,7 +1710,7 @@ def test_execute_sql_job_labels( ("tool_call", "expected_label"), [ pytest.param( - lambda tool_context: forecast( + lambda tool_context: query_tool.forecast( project_id="test-project", history_data="SELECT * FROM `test-dataset.test-table`", timestamp_col="ts_col", @@ -1726,7 +1723,7 @@ def test_execute_sql_job_labels( id="forecast", ), pytest.param( - lambda tool_context: analyze_contribution( + lambda tool_context: query_tool.analyze_contribution( project_id="test-project", input_data="test-dataset.test-table", dimension_id_cols=["dim1", "dim2"], @@ -1740,7 +1737,7 @@ def test_execute_sql_job_labels( id="analyze-contribution", ), pytest.param( - lambda tool_context: detect_anomalies( + lambda tool_context: query_tool.detect_anomalies( project_id="test-project", history_data="SELECT * FROM `test-dataset.test-table`", times_series_timestamp_col="ts_timestamp", @@ -1757,7 +1754,7 @@ def test_execute_sql_job_labels( def test_ml_tool_job_labels(tool_call, expected_label): """Test ML tools for job label.""" - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value tool_context = mock.create_autospec(ToolContext, instance=True) @@ -1785,14 +1782,16 @@ def test_execute_sql_max_rows_config(): tool_config = BigQueryToolConfig(max_query_result_rows=10) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value query_job = mock.create_autospec(bigquery.QueryJob) query_job.statement_type = statement_type bq_client.query.return_value = query_job bq_client.query_and_wait.return_value = query_result[:10] - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = query_tool.execute_sql( + project, query, credentials, tool_config, tool_context + ) # Check that max_results was called with config value bq_client.query_and_wait.assert_called_once() @@ -1814,14 +1813,16 @@ def test_execute_sql_no_truncation(): tool_config = BigQueryToolConfig(max_query_result_rows=10) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value query_job = mock.create_autospec(bigquery.QueryJob) query_job.statement_type = statement_type bq_client.query.return_value = query_job bq_client.query_and_wait.return_value = query_result - result = execute_sql(project, query, credentials, tool_config, tool_context) + result = query_tool.execute_sql( + project, query, credentials, tool_config, tool_context + ) # Check no truncation flag when fewer rows than limit assert result["status"] == "SUCCESS" @@ -1837,13 +1838,15 @@ def test_execute_sql_maximum_bytes_billed_config(): tool_config = BigQueryToolConfig(maximum_bytes_billed=11_000_000) tool_context = mock.create_autospec(ToolContext, instance=True) - with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + with mock.patch.object(bigquery, "Client", autospec=True) as Client: bq_client = Client.return_value query_job = mock.create_autospec(bigquery.QueryJob) query_job.statement_type = statement_type bq_client.query.return_value = query_job - execute_sql(project, query, credentials, tool_config, tool_context) + query_tool.execute_sql( + project, query, credentials, tool_config, tool_context + ) # Check that maximum_bytes_billed was called with config value bq_client.query_and_wait.assert_called_once() @@ -1855,7 +1858,7 @@ def test_execute_sql_maximum_bytes_billed_config(): ("tool_call",), [ pytest.param( - lambda settings, tool_context: execute_sql( + lambda settings, tool_context: query_tool.execute_sql( project_id="test-project", query="SELECT * FROM `test-dataset.test-table`", credentials=mock.create_autospec(Credentials, instance=True), @@ -1865,7 +1868,7 @@ def test_execute_sql_maximum_bytes_billed_config(): id="execute-sql", ), pytest.param( - lambda settings, tool_context: forecast( + lambda settings, tool_context: query_tool.forecast( project_id="test-project", history_data="SELECT * FROM `test-dataset.test-table`", timestamp_col="ts_col", @@ -1877,7 +1880,7 @@ def test_execute_sql_maximum_bytes_billed_config(): id="forecast", ), pytest.param( - lambda settings, tool_context: analyze_contribution( + lambda settings, tool_context: query_tool.analyze_contribution( project_id="test-project", input_data="test-dataset.test-table", dimension_id_cols=["dim1", "dim2"], @@ -1890,7 +1893,7 @@ def test_execute_sql_maximum_bytes_billed_config(): id="analyze-contribution", ), pytest.param( - lambda settings, tool_context: detect_anomalies( + lambda settings, tool_context: query_tool.detect_anomalies( project_id="test-project", history_data="SELECT * FROM `test-dataset.test-table`", times_series_timestamp_col="ts_timestamp",