test: Add unit tests for execute_sql tool

This change introduces unit tests in which the behavior of the tool is asserted for various query types in various write modes through a mocked BigQuery client.

PiperOrigin-RevId: 770653117
This commit is contained in:
Google Team Member
2025-06-12 07:52:08 -07:00
committed by Copybara-Service
parent 0a5cf45a75
commit 2ff9b1f639
@@ -16,12 +16,16 @@ from __future__ import annotations
import textwrap
from typing import Optional
from unittest import mock
from google.adk.tools import BaseTool
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
from google.adk.tools.bigquery.query_tool import execute_sql
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
import pytest
@@ -218,3 +222,123 @@ async def test_execute_sql_declaration_write(tool_config):
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
- First run "DROP TABLE", followed by "CREATE TABLE".
- To insert data into a table, use "INSERT INTO" statement.""")
@pytest.mark.parametrize(
("write_mode",),
[
pytest.param(
WriteMode.BLOCKED,
id="blocked",
),
pytest.param(
WriteMode.ALLOWED,
id="allowed",
),
],
)
def test_execute_sql_select_stmt(write_mode):
"""Test execute_sql tool for SELECT query when writes are blocked."""
project = "my_project"
query = "SELECT 123 AS num"
statement_type = "SELECT"
query_result = [{"num": 123}]
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=write_mode)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
# The mock instance
bq_client = Client.return_value
# Simulate the result of query API
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
# Simulate the result of query_and_wait API
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config)
assert result == {"status": "SUCCESS", "rows": query_result}
@pytest.mark.parametrize(
("query", "statement_type"),
[
pytest.param(
"CREATE TABLE my_dataset.my_table AS SELECT 123 AS num",
"CREATE_AS_SELECT",
id="create-as-select",
),
pytest.param(
"DROP TABLE my_dataset.my_table",
"DROP_TABLE",
id="drop-table",
),
],
)
def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
"""Test execute_sql tool for SELECT query when writes are blocked."""
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
# The mock instance
bq_client = Client.return_value
# Simulate the result of query API
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
# Simulate the result of query_and_wait API
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config)
assert result == {"status": "SUCCESS", "rows": query_result}
@pytest.mark.parametrize(
("query", "statement_type"),
[
pytest.param(
"CREATE TABLE my_dataset.my_table AS SELECT 123 AS num",
"CREATE_AS_SELECT",
id="create-as-select",
),
pytest.param(
"DROP TABLE my_dataset.my_table",
"DROP_TABLE",
id="drop-table",
),
],
)
def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
"""Test execute_sql tool for SELECT query when writes are blocked."""
project = "my_project"
query_result = []
credentials = mock.create_autospec(Credentials, instance=True)
tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
# The mock instance
bq_client = Client.return_value
# Simulate the result of query API
query_job = mock.create_autospec(bigquery.QueryJob)
query_job.statement_type = statement_type
bq_client.query.return_value = query_job
# Simulate the result of query_and_wait API
bq_client.query_and_wait.return_value = query_result
# Test the tool
result = execute_sql(project, query, credentials, tool_config)
assert result == {
"status": "ERROR",
"error_details": "Read-only mode only supports SELECT statements.",
}