From dc43d518c90b44932b3fdedd33fca9e6c87704e2 Mon Sep 17 00:00:00 2001
From: Google Team Member
Date: Fri, 27 Jun 2025 11:42:40 -0700
Subject: [PATCH] feat: Support protected write in BigQuery `execute_sql` tool

This change adds a new enum value that the agent builder can pass in the
`BigQueryToolConfig` to allow limited writes via the `execute_sql` tool.

PiperOrigin-RevId: 776661744
---
 contributing/samples/bigquery/agent.py      |   6 +-
 src/google/adk/tools/bigquery/config.py     |  12 +-
 src/google/adk/tools/bigquery/query_tool.py | 318 +++++++++++-
 .../bigquery/test_bigquery_query_tool.py    | 479 +++++++++++++++++-
 4 files changed, 786 insertions(+), 29 deletions(-)

diff --git a/contributing/samples/bigquery/agent.py b/contributing/samples/bigquery/agent.py
index c1b265c0..b78f7968 100644
--- a/contributing/samples/bigquery/agent.py
+++ b/contributing/samples/bigquery/agent.py
@@ -26,7 +26,11 @@ import google.auth
 
 CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2
 
-# Define BigQuery tool config
+# Define BigQuery tool config with write mode set to allowed. Note that this is
+# only to demonstrate the full capability of the BigQuery tools. In production
+# you may want to change to the BLOCKED write mode (the default, which
+# effectively makes the tool read-only) or the PROTECTED write mode (which only
+# allows writes in the anonymous dataset of a BigQuery session).
 tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
 
 if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py
index 606f86e3..a6f8eeb5 100644
--- a/src/google/adk/tools/bigquery/config.py
+++ b/src/google/adk/tools/bigquery/config.py
@@ -26,10 +26,20 @@ class WriteMode(Enum):
 
   BLOCKED = 'blocked'
   """No write operations are allowed.
-  
+
   This mode implies that only read (i.e. SELECT query) operations are allowed.
   """
 
+  PROTECTED = 'protected'
+  """Only protected write operations are allowed in a BigQuery session.
+
+  In this mode write operations in the anonymous dataset of a BigQuery session
+  are allowed. For example, a temporary table can be created, manipulated and
+  deleted in the anonymous dataset during agent interaction, while protecting
+  permanent tables from being modified or deleted. To learn more about BigQuery
+  sessions, see https://cloud.google.com/bigquery/docs/sessions-intro.
+  """
+
   ALLOWED = 'allowed'
   """All write operations are allowed."""
 
diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py
index 147d0b4d..7406d9a4 100644
--- a/src/google/adk/tools/bigquery/query_tool.py
+++ b/src/google/adk/tools/bigquery/query_tool.py
@@ -20,10 +20,12 @@ from google.auth.credentials import Credentials
 from google.cloud import bigquery
 
 from . import client
+from ..tool_context import ToolContext
 from .config import BigQueryToolConfig
 from .config import WriteMode
 
 MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
+BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info"
 
 
 def execute_sql(
@@ -31,14 +33,17 @@ def execute_sql(
     query: str,
     credentials: Credentials,
     config: BigQueryToolConfig,
+    tool_context: ToolContext,
 ) -> dict:
-  """Run a BigQuery SQL query in the project and return the result.
+  """Run a BigQuery or BigQuery ML SQL query in the project and return the result.
 
   Args:
     project_id (str): The GCP project id in which the query should be
      executed.
    query (str): The BigQuery SQL query to be executed.
    credentials (Credentials): The credentials to use for the request.
+    config (BigQueryToolConfig): The configuration for the tool.
+    tool_context (ToolContext): The context for the tool.
 
   Returns:
     dict: Dictionary representing the result of the query.
@@ -49,11 +54,11 @@ def execute_sql(
 
   Examples:
     Fetch data or insights from a table:
 
-    >>> execute_sql("my_project",
+    >>> execute_sql("my_project",
     ... "SELECT island, COUNT(*) AS population "
     ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
     {
-      "status": "ERROR",
+      "status": "SUCCESS",
      "rows": [
        {
          "island": "Dream",
          "population": 124
        },
@@ -72,23 +77,87 @@ def execute_sql(
   """
   try:
+    # Get BigQuery client
     bq_client = client.get_bigquery_client(
         project=project_id, credentials=credentials
     )
+
+    # BigQuery connection properties where applicable
+    bq_connection_properties = None
+
     if not config or config.write_mode == WriteMode.BLOCKED:
-      query_job = bq_client.query(
+      dry_run_query_job = bq_client.query(
           query,
           project=project_id,
           job_config=bigquery.QueryJobConfig(dry_run=True),
       )
-      if query_job.statement_type != "SELECT":
+      if dry_run_query_job.statement_type != "SELECT":
         return {
             "status": "ERROR",
            "error_details": "Read-only mode only supports SELECT statements.",
         }
+    elif config.write_mode == WriteMode.PROTECTED:
+      # In protected write mode, only write operations on a temporary artifact
+      # are allowed. This artifact must have been created in a BigQuery
+      # session. In such a scenario the session info (session id and the
+      # anonymous dataset containing the artifact) is persisted in the tool
+      # context.
+      bq_session_info = tool_context.state.get(BIGQUERY_SESSION_INFO_KEY, None)
+      if bq_session_info:
+        bq_session_id, bq_session_dataset_id = bq_session_info
+      else:
+        session_creator_job = bq_client.query(
+            "SELECT 1",
+            project=project_id,
+            job_config=bigquery.QueryJobConfig(
+                dry_run=True, create_session=True
+            ),
+        )
+        bq_session_id = session_creator_job.session_info.session_id
+        bq_session_dataset_id = session_creator_job.destination.dataset_id
+        # Remember the BigQuery session info for subsequent queries
+        tool_context.state[BIGQUERY_SESSION_INFO_KEY] = (
+            bq_session_id,
+            bq_session_dataset_id,
+        )
+
+      # The session connection property will be set in the query execution
+      bq_connection_properties = [
+          bigquery.ConnectionProperty("session_id", bq_session_id)
+      ]
+
+      # Check the query type w.r.t. the BigQuery session
+      dry_run_query_job = bq_client.query(
+          query,
+          project=project_id,
+          job_config=bigquery.QueryJobConfig(
+              dry_run=True,
+              connection_properties=bq_connection_properties,
+          ),
+      )
+      if (
+          dry_run_query_job.statement_type != "SELECT"
+          and dry_run_query_job.destination.dataset_id != bq_session_dataset_id
+      ):
+        return {
+            "status": "ERROR",
+            "error_details": (
+                "Protected write mode only supports SELECT statements, or write"
+                " operations in the anonymous dataset of a BigQuery session."
+            ),
+        }
+
+    # Finally, execute the query and fetch the result
+    job_config = (
+        bigquery.QueryJobConfig(connection_properties=bq_connection_properties)
+        if bq_connection_properties
+        else None
+    )
     row_iterator = bq_client.query_and_wait(
-        query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
+        query,
+        job_config=job_config,
+        project=project_id,
+        max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS,
     )
     rows = [{key: val for key, val in row.items()} for row in row_iterator]
     result = {"status": "SUCCESS", "rows": rows}
@@ -106,9 +175,29 @@
 
 
 _execute_sql_write_examples = """
+    Create a table with schema prescribed:
+
+    >>> execute_sql("my_project",
"CREATE TABLE my_project.my_dataset.my_table " + ... "(island STRING, population INT64)") + { + "status": "SUCCESS", + "rows": [] + } + + Insert data into an existing table: + + >>> execute_sql("my_project", + ... "INSERT INTO my_project.my_dataset.my_table (island, population) " + ... "VALUES ('Dream', 124), ('Biscoe', 168)") + { + "status": "SUCCESS", + "rows": [] + } + Create a table from the result of a query: - >>> execute_sql("bigframes-dev", + >>> execute_sql("my_project", ... "CREATE TABLE my_project.my_dataset.my_table AS " ... "SELECT island, COUNT(*) AS population " ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") @@ -119,7 +208,7 @@ _execute_sql_write_examples = """ Delete a table: - >>> execute_sql("bigframes-dev", + >>> execute_sql("my_project", ... "DROP TABLE my_project.my_dataset.my_table") { "status": "SUCCESS", @@ -128,7 +217,7 @@ _execute_sql_write_examples = """ Copy a table to another table: - >>> execute_sql("bigframes-dev", + >>> execute_sql("my_project", ... "CREATE TABLE my_project.my_dataset.my_table_clone " ... "CLONE my_project.my_dataset.my_table") { @@ -139,7 +228,7 @@ _execute_sql_write_examples = """ Create a snapshot (a lightweight, read-optimized copy) of en existing table: - >>> execute_sql("bigframes-dev", + >>> execute_sql("my_project", ... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot " ... "CLONE my_project.my_dataset.my_table") { @@ -147,12 +236,214 @@ _execute_sql_write_examples = """ "rows": [] } + Create a BigQuery ML linear regression model: + + >>> execute_sql("my_project", + ... "CREATE MODEL `my_dataset.my_model` " + ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS " + ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "WHERE body_mass_g IS NOT NULL") + { + "status": "SUCCESS", + "rows": [] + } + + Evaluate BigQuery ML model: + + >>> execute_sql("my_project", + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)") + { + "status": "SUCCESS", + "rows": [{'mean_absolute_error': 227.01223667447218, + 'mean_squared_error': 81838.15989216768, + 'mean_squared_log_error': 0.0050704473735013, + 'median_absolute_error': 173.08081641661738, + 'r2_score': 0.8723772534253441, + 'explained_variance': 0.8723772534253442}] + } + + Evaluate BigQuery ML model on custom data: + + >>> execute_sql("my_project", + ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, " + ... "(SELECT * FROM `my_dataset.my_table`))") + { + "status": "SUCCESS", + "rows": [{'mean_absolute_error': 227.01223667447218, + 'mean_squared_error': 81838.15989216768, + 'mean_squared_log_error': 0.0050704473735013, + 'median_absolute_error': 173.08081641661738, + 'r2_score': 0.8723772534253441, + 'explained_variance': 0.8723772534253442}] + } + + Predict using BigQuery ML model: + + >>> execute_sql("my_project", + ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, " + ... "(SELECT * FROM `my_dataset.my_table`))") + { + "status": "SUCCESS", + "rows": [ + { + "predicted_body_mass_g": "3380.9271650847013", + ... + }, { + "predicted_body_mass_g": "3873.6072435386004", + ... + }, + ... + ] + } + + Delete a BigQuery ML model: + + >>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`") + { + "status": "SUCCESS", + "rows": [] + } + Notes: - If a destination table already exists, there are a few ways to overwrite it: - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE". - First run "DROP TABLE", followed by "CREATE TABLE". 
-      - To insert data into a table, use "INSERT INTO" statement.
+      - If a model already exists, there are a few ways to overwrite it:
+          - Use "CREATE OR REPLACE MODEL" instead of "CREATE MODEL".
+          - First run "DROP MODEL", followed by "CREATE MODEL".
 """
+
+
+_execute_sql_protected_write_examples = """
+    Create a temporary table with schema prescribed:
+
+    >>> execute_sql("my_project",
+    ... "CREATE TEMP TABLE my_table (island STRING, population INT64)")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Insert data into an existing temporary table:
+
+    >>> execute_sql("my_project",
+    ... "INSERT INTO my_table (island, population) "
+    ... "VALUES ('Dream', 124), ('Biscoe', 168)")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Create a temporary table from the result of a query:
+
+    >>> execute_sql("my_project",
+    ... "CREATE TEMP TABLE my_table AS "
+    ... "SELECT island, COUNT(*) AS population "
+    ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Delete a temporary table:
+
+    >>> execute_sql("my_project", "DROP TABLE my_table")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Copy a temporary table to another temporary table:
+
+    >>> execute_sql("my_project",
+    ... "CREATE TEMP TABLE my_table_clone CLONE my_table")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Create a temporary BigQuery ML linear regression model:
+
+    >>> execute_sql("my_project",
+    ... "CREATE TEMP MODEL my_model "
+    ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS "
+    ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
+    ... "WHERE body_mass_g IS NOT NULL")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Evaluate BigQuery ML model:
+
+    >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)")
+    {
+      "status": "SUCCESS",
+      "rows": [{'mean_absolute_error': 227.01223667447218,
+                'mean_squared_error': 81838.15989216768,
+                'mean_squared_log_error': 0.0050704473735013,
+                'median_absolute_error': 173.08081641661738,
+                'r2_score': 0.8723772534253441,
+                'explained_variance': 0.8723772534253442}]
+    }
+
+    Evaluate BigQuery ML model on custom data:
+
+    >>> execute_sql("my_project",
+    ... "SELECT * FROM ML.EVALUATE(MODEL my_model, "
+    ... "(SELECT * FROM `my_dataset.my_table`))")
+    {
+      "status": "SUCCESS",
+      "rows": [{'mean_absolute_error': 227.01223667447218,
+                'mean_squared_error': 81838.15989216768,
+                'mean_squared_log_error': 0.0050704473735013,
+                'median_absolute_error': 173.08081641661738,
+                'r2_score': 0.8723772534253441,
+                'explained_variance': 0.8723772534253442}]
+    }
+
+    Predict using BigQuery ML model:
+
+    >>> execute_sql("my_project",
+    ... "SELECT * FROM ML.PREDICT(MODEL my_model, "
+    ... "(SELECT * FROM `my_dataset.my_table`))")
+    {
+      "status": "SUCCESS",
+      "rows": [
+        {
+          "predicted_body_mass_g": "3380.9271650847013",
+          ...
+        }, {
+          "predicted_body_mass_g": "3873.6072435386004",
+          ...
+        },
+        ...
+      ]
+    }
+
+    Delete a BigQuery ML model:
+
+    >>> execute_sql("my_project", "DROP MODEL my_model")
+    {
+      "status": "SUCCESS",
+      "rows": []
+    }
+
+    Notes:
+      - If a destination table already exists, there are a few ways to
+        overwrite it:
+          - Use "CREATE OR REPLACE TEMP TABLE" instead of "CREATE TEMP TABLE".
+          - First run "DROP TABLE", followed by "CREATE TEMP TABLE".
+      - Only temporary tables can be created, inserted into or deleted. Please
+        do not try creating a permanent table (non-TEMP table), inserting into
+        or deleting one.
+      - If a destination model already exists, there are a few ways to
+        overwrite it:
+          - Use "CREATE OR REPLACE TEMP MODEL" instead of "CREATE TEMP MODEL".
+          - First run "DROP MODEL", followed by "CREATE TEMP MODEL".
+      - Only temporary models can be created or deleted. Please do not try
+        creating a permanent model (non-TEMP model) or deleting one.
+"""
 
 
@@ -189,6 +480,9 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
   functools.update_wrapper(execute_sql_wrapper, execute_sql)
 
   # Now, set the new docstring
-  execute_sql_wrapper.__doc__ += _execute_sql_write_examples
+  if config.write_mode == WriteMode.PROTECTED:
+    execute_sql_wrapper.__doc__ += _execute_sql_protected_write_examples
+  else:
+    execute_sql_wrapper.__doc__ += _execute_sql_write_examples
 
   return execute_sql_wrapper
diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
index 3cb8c3c4..c42e3881 100644
--- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
+++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py
@@ -25,6 +25,7 @@
 from google.adk.tools.bigquery import BigQueryToolset
 from google.adk.tools.bigquery.config import BigQueryToolConfig
 from google.adk.tools.bigquery.config import WriteMode
 from google.adk.tools.bigquery.query_tool import execute_sql
+from google.adk.tools.tool_context import ToolContext
 from google.auth.exceptions import DefaultCredentialsError
 from google.cloud import bigquery
 from google.oauth2.credentials import Credentials
@@ -80,13 +81,15 @@ async def test_execute_sql_declaration_read_only(tool_config):
   tool = await get_tool(tool_name, tool_config)
   assert tool.name == tool_name
   assert tool.description == textwrap.dedent("""\
-      Run a BigQuery SQL query in the project and return the result.
+      Run a BigQuery or BigQuery ML SQL query in the project and return the result.
 
       Args:
        project_id (str): The GCP project id in which the query should be
          executed.
        query (str): The BigQuery SQL query to be executed.
        credentials (Credentials): The credentials to use for the request.
+        config (BigQueryToolConfig): The configuration for the tool.
+        tool_context (ToolContext): The context for the tool.
 
       Returns:
        dict: Dictionary representing the result of the query.
@@ -97,11 +100,11 @@ async def test_execute_sql_declaration_read_only(tool_config):
 
       Examples:
        Fetch data or insights from a table:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "SELECT island, COUNT(*) AS population "
        ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
        {
-          "status": "ERROR",
+          "status": "SUCCESS",
          "rows": [
            {
              "island": "Dream",
@@ -139,13 +142,15 @@ async def test_execute_sql_declaration_write(tool_config):
   tool = await get_tool(tool_name, tool_config)
   assert tool.name == tool_name
   assert tool.description == textwrap.dedent("""\
-      Run a BigQuery SQL query in the project and return the result.
+      Run a BigQuery or BigQuery ML SQL query in the project and return the result.
 
       Args:
        project_id (str): The GCP project id in which the query should be
          executed.
        query (str): The BigQuery SQL query to be executed.
        credentials (Credentials): The credentials to use for the request.
+        config (BigQueryToolConfig): The configuration for the tool.
+        tool_context (ToolContext): The context for the tool.
 
       Returns:
        dict: Dictionary representing the result of the query.
@@ -156,11 +161,11 @@ async def test_execute_sql_declaration_write(tool_config):
 
       Examples:
        Fetch data or insights from a table:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "SELECT island, COUNT(*) AS population "
        ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
        {
-          "status": "ERROR",
+          "status": "SUCCESS",
          "rows": [
            {
              "island": "Dream",
              "population": 124
            },
@@ -177,9 +182,29 @@ async def test_execute_sql_declaration_write(tool_config):
            ...
          ]
        }
 
+        Create a table with schema prescribed:
+
+        >>> execute_sql("my_project",
+        ... "CREATE TABLE my_project.my_dataset.my_table "
+        ... "(island STRING, population INT64)")
+        {
+          "status": "SUCCESS",
+          "rows": []
+        }
+
+        Insert data into an existing table:
+
+        >>> execute_sql("my_project",
+        ... "INSERT INTO my_project.my_dataset.my_table (island, population) "
+        ... "VALUES ('Dream', 124), ('Biscoe', 168)")
+        {
+          "status": "SUCCESS",
+          "rows": []
+        }
+
        Create a table from the result of a query:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "CREATE TABLE my_project.my_dataset.my_table AS "
        ... "SELECT island, COUNT(*) AS population "
        ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
@@ -190,7 +215,7 @@ async def test_execute_sql_declaration_write(tool_config):
 
        Delete a table:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "DROP TABLE my_project.my_dataset.my_table")
        {
          "status": "SUCCESS",
          "rows": []
        }
 
        Copy a table to another table:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "CREATE TABLE my_project.my_dataset.my_table_clone "
        ... "CLONE my_project.my_dataset.my_table")
        {
@@ -210,7 +235,7 @@ async def test_execute_sql_declaration_write(tool_config):
 
        Create a snapshot (a lightweight, read-optimized copy) of an existing table:
 
-        >>> execute_sql("bigframes-dev",
+        >>> execute_sql("my_project",
        ... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
        ... "CLONE my_project.my_dataset.my_table")
        {
          "status": "SUCCESS",
          "rows": []
        }
 
@@ -218,18 +243,279 @@ async def test_execute_sql_declaration_write(tool_config):
+        Create a BigQuery ML linear regression model:
+
+        >>> execute_sql("my_project",
+        ... "CREATE MODEL `my_dataset.my_model` "
+        ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS "
+        ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
+        ... "WHERE body_mass_g IS NOT NULL")
+        {
+          "status": "SUCCESS",
+          "rows": []
+        }
+
+        Evaluate BigQuery ML model:
+
+        >>> execute_sql("my_project",
+        ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)")
+        {
+          "status": "SUCCESS",
+          "rows": [{'mean_absolute_error': 227.01223667447218,
+                    'mean_squared_error': 81838.15989216768,
+                    'mean_squared_log_error': 0.0050704473735013,
+                    'median_absolute_error': 173.08081641661738,
+                    'r2_score': 0.8723772534253441,
+                    'explained_variance': 0.8723772534253442}]
+        }
+
+        Evaluate BigQuery ML model on custom data:
+
+        >>> execute_sql("my_project",
+        ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, "
+        ... "(SELECT * FROM `my_dataset.my_table`))")
+        {
+          "status": "SUCCESS",
+          "rows": [{'mean_absolute_error': 227.01223667447218,
+                    'mean_squared_error': 81838.15989216768,
+                    'mean_squared_log_error': 0.0050704473735013,
+                    'median_absolute_error': 173.08081641661738,
+                    'r2_score': 0.8723772534253441,
+                    'explained_variance': 0.8723772534253442}]
+        }
+
+        Predict using BigQuery ML model:
+
+        >>> execute_sql("my_project",
"SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, " + ... "(SELECT * FROM `my_dataset.my_table`))") + { + "status": "SUCCESS", + "rows": [ + { + "predicted_body_mass_g": "3380.9271650847013", + ... + }, { + "predicted_body_mass_g": "3873.6072435386004", + ... + }, + ... + ] + } + + Delete a BigQuery ML model: + + >>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`") + { + "status": "SUCCESS", + "rows": [] + } + Notes: - If a destination table already exists, there are a few ways to overwrite it: - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE". - First run "DROP TABLE", followed by "CREATE TABLE". - - To insert data into a table, use "INSERT INTO" statement.""") + - If a model already exists, there are a few ways to overwrite it: + - Use "CREATE OR REPLACE MODEL" instead of "CREATE MODEL". + - First run "DROP MODEL", followed by "CREATE MODEL".""") + + +@pytest.mark.parametrize( + ("tool_config",), + [ + pytest.param( + BigQueryToolConfig(write_mode=WriteMode.PROTECTED), + id="explicit-protected-write", + ), + ], +) +@pytest.mark.asyncio +async def test_execute_sql_declaration_protected_write(tool_config): + """Test BigQuery execute_sql tool declaration with protected writes enabled. + + This test verifies that the execute_sql tool declaration reflects the + protected write capability. + """ + tool_name = "execute_sql" + tool = await get_tool(tool_name, tool_config) + assert tool.name == tool_name + assert tool.description == textwrap.dedent("""\ + Run a BigQuery or BigQuery ML SQL query in the project and return the result. + + Args: + project_id (str): The GCP project id in which the query should be + executed. + query (str): The BigQuery SQL query to be executed. + credentials (Credentials): The credentials to use for the request. + config (BigQueryToolConfig): The configuration for the tool. + tool_context (ToolContext): The context for the tool. + + Returns: + dict: Dictionary representing the result of the query. + If the result contains the key "result_is_likely_truncated" with + value True, it means that there may be additional rows matching the + query not returned in the result. + + Examples: + Fetch data or insights from a table: + + >>> execute_sql("my_project", + ... "SELECT island, COUNT(*) AS population " + ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + { + "status": "SUCCESS", + "rows": [ + { + "island": "Dream", + "population": 124 + }, + { + "island": "Biscoe", + "population": 168 + }, + { + "island": "Torgersen", + "population": 52 + } + ] + } + + Create a temporary table with schema prescribed: + + >>> execute_sql("my_project", + ... "CREATE TEMP TABLE my_table (island STRING, population INT64)") + { + "status": "SUCCESS", + "rows": [] + } + + Insert data into an existing temporary table: + + >>> execute_sql("my_project", + ... "INSERT INTO my_table (island, population) " + ... "VALUES ('Dream', 124), ('Biscoe', 168)") + { + "status": "SUCCESS", + "rows": [] + } + + Create a temporary table from the result of a query: + + >>> execute_sql("my_project", + ... "CREATE TEMP TABLE my_table AS " + ... "SELECT island, COUNT(*) AS population " + ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island") + { + "status": "SUCCESS", + "rows": [] + } + + Delete a temporary table: + + >>> execute_sql("my_project", "DROP TABLE my_table") + { + "status": "SUCCESS", + "rows": [] + } + + Copy a temporary table to another temporary table: + + >>> execute_sql("my_project", + ... 
"CREATE TEMP TABLE my_table_clone CLONE my_table") + { + "status": "SUCCESS", + "rows": [] + } + + Create a temporary BigQuery ML linear regression model: + + >>> execute_sql("my_project", + ... "CREATE TEMP MODEL my_model " + ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS" + ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` " + ... "WHERE body_mass_g IS NOT NULL") + { + "status": "SUCCESS", + "rows": [] + } + + Evaluate BigQuery ML model: + + >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)") + { + "status": "SUCCESS", + "rows": [{'mean_absolute_error': 227.01223667447218, + 'mean_squared_error': 81838.15989216768, + 'mean_squared_log_error': 0.0050704473735013, + 'median_absolute_error': 173.08081641661738, + 'r2_score': 0.8723772534253441, + 'explained_variance': 0.8723772534253442}] + } + + Evaluate BigQuery ML model on custom data: + + >>> execute_sql("my_project", + ... "SELECT * FROM ML.EVALUATE(MODEL my_model, " + ... "(SELECT * FROM `my_dataset.my_table`))") + { + "status": "SUCCESS", + "rows": [{'mean_absolute_error': 227.01223667447218, + 'mean_squared_error': 81838.15989216768, + 'mean_squared_log_error': 0.0050704473735013, + 'median_absolute_error': 173.08081641661738, + 'r2_score': 0.8723772534253441, + 'explained_variance': 0.8723772534253442}] + } + + Predict using BigQuery ML model: + + >>> execute_sql("my_project", + ... "SELECT * FROM ML.PREDICT(MODEL my_model, " + ... "(SELECT * FROM `my_dataset.my_table`))") + { + "status": "SUCCESS", + "rows": [ + { + "predicted_body_mass_g": "3380.9271650847013", + ... + }, { + "predicted_body_mass_g": "3873.6072435386004", + ... + }, + ... + ] + } + + Delete a BigQuery ML model: + + >>> execute_sql("my_project", "DROP MODEL my_model") + { + "status": "SUCCESS", + "rows": [] + } + + Notes: + - If a destination table already exists, there are a few ways to overwrite + it: + - Use "CREATE OR REPLACE TEMP TABLE" instead of "CREATE TEMP TABLE". + - First run "DROP TABLE", followed by "CREATE TEMP TABLE". + - Only temporary tables can be created, inserted into or deleted. Please + do not try creating a permanent table (non-TEMP table), inserting into or + deleting one. + - If a destination model already exists, there are a few ways to overwrite + it: + - Use "CREATE OR REPLACE TEMP MODEL" instead of "CREATE TEMP MODEL". + - First run "DROP MODEL", followed by "CREATE TEMP MODEL". + - Only temporary models can be created or deleted. 
+          Please do not try creating a permanent model (non-TEMP model) or
+          deleting one.""")
 
 
 @pytest.mark.parametrize(
     ("write_mode",),
     [
         pytest.param(WriteMode.BLOCKED, id="blocked"),
+        pytest.param(WriteMode.PROTECTED, id="protected"),
         pytest.param(WriteMode.ALLOWED, id="allowed"),
     ],
 )
@@ -241,6 +527,11 @@ def test_execute_sql_select_stmt(write_mode):
   query_result = [{"num": 123}]
   credentials = mock.create_autospec(Credentials, instance=True)
   tool_config = BigQueryToolConfig(write_mode=write_mode)
+  tool_context = mock.create_autospec(ToolContext, instance=True)
+  tool_context.state.get.return_value = (
+      "test-bq-session-id",
+      "_anonymous_dataset",
+  )
 
   with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
     # The mock instance
     bq_client = Client.return_value
 
@@ -255,7 +546,7 @@ def test_execute_sql_select_stmt(write_mode):
     bq_client.query_and_wait.return_value = query_result
 
     # Test the tool
-    result = execute_sql(project, query, credentials, tool_config)
+    result = execute_sql(project, query, credentials, tool_config, tool_context)
     assert result == {"status": "SUCCESS", "rows": query_result}
 
 
@@ -272,6 +563,18 @@ def test_execute_sql_select_stmt(write_mode):
             "DROP_TABLE",
             id="drop-table",
         ),
+        pytest.param(
+            "CREATE MODEL my_dataset.my_model (model_type='linear_reg',"
+            " input_label_cols=['label_col']) AS SELECT * FROM"
+            " my_dataset.my_table",
+            "CREATE_MODEL",
+            id="create-model",
+        ),
+        pytest.param(
+            "DROP MODEL my_dataset.my_model",
+            "DROP_MODEL",
+            id="drop-model",
+        ),
     ],
 )
 def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
   project = "my_project"
   query_result = []
   credentials = mock.create_autospec(Credentials, instance=True)
   tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
+  tool_context = mock.create_autospec(ToolContext, instance=True)
 
   with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
     # The mock instance
     bq_client = Client.return_value
 
@@ -294,7 +598,7 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
     bq_client.query_and_wait.return_value = query_result
 
     # Test the tool
-    result = execute_sql(project, query, credentials, tool_config)
+    result = execute_sql(project, query, credentials, tool_config, tool_context)
     assert result == {"status": "SUCCESS", "rows": query_result}
 
 
@@ -311,6 +615,18 @@ def test_execute_sql_non_select_stmt_write_allowed(query, statement_type):
             "DROP_TABLE",
             id="drop-table",
         ),
+        pytest.param(
+            "CREATE MODEL my_dataset.my_model (model_type='linear_reg',"
+            " input_label_cols=['label_col']) AS SELECT * FROM"
+            " my_dataset.my_table",
+            "CREATE_MODEL",
+            id="create-model",
+        ),
+        pytest.param(
+            "DROP MODEL my_dataset.my_model",
+            "DROP_MODEL",
+            id="drop-model",
+        ),
     ],
 )
 def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
   project = "my_project"
   query_result = []
   credentials = mock.create_autospec(Credentials, instance=True)
   tool_config = BigQueryToolConfig(write_mode=WriteMode.BLOCKED)
+  tool_context = mock.create_autospec(ToolContext, instance=True)
 
   with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
     # The mock instance
     bq_client = Client.return_value
 
@@ -333,17 +650,144 @@ def test_execute_sql_non_select_stmt_write_blocked(query, statement_type):
     bq_client.query_and_wait.return_value = query_result
 
     # Test the tool
-    result = execute_sql(project, query, credentials, tool_config)
+    result = execute_sql(project, query, credentials, tool_config, tool_context)
     assert result == {
         "status": "ERROR",
         "error_details": "Read-only mode only supports SELECT statements.",
     }
 
 
+@pytest.mark.parametrize(
+    ("query", "statement_type"),
+    [
+        pytest.param(
+            "CREATE TEMP TABLE my_table AS SELECT 123 AS num",
+            "CREATE_AS_SELECT",
+            id="create-as-select",
+        ),
+        pytest.param(
+            "DROP TABLE my_table",
+            "DROP_TABLE",
+            id="drop-table",
+        ),
+        pytest.param(
+            "CREATE TEMP MODEL my_model (model_type='linear_reg',"
+            " input_label_cols=['label_col']) AS SELECT * FROM"
+            " my_dataset.my_table",
+            "CREATE_MODEL",
+            id="create-model",
+        ),
+        pytest.param(
+            "DROP MODEL my_model",
+            "DROP_MODEL",
+            id="drop-model",
+        ),
+    ],
+)
+def test_execute_sql_non_select_stmt_write_protected(query, statement_type):
+  """Test execute_sql tool for non-SELECT query when writes are protected."""
+  project = "my_project"
+  query_result = []
+  credentials = mock.create_autospec(Credentials, instance=True)
+  tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)
+  tool_context = mock.create_autospec(ToolContext, instance=True)
+  tool_context.state.get.return_value = (
+      "test-bq-session-id",
+      "_anonymous_dataset",
+  )
+
+  with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
+    # The mock instance
+    bq_client = Client.return_value
+
+    # Simulate the result of query API
+    query_job = mock.create_autospec(bigquery.QueryJob)
+    query_job.statement_type = statement_type
+    query_job.destination.dataset_id = "_anonymous_dataset"
+    bq_client.query.return_value = query_job
+
+    # Simulate the result of query_and_wait API
+    bq_client.query_and_wait.return_value = query_result
+
+    # Test the tool
+    result = execute_sql(project, query, credentials, tool_config, tool_context)
+    assert result == {"status": "SUCCESS", "rows": query_result}
+
+
+@pytest.mark.parametrize(
+    ("query", "statement_type"),
+    [
+        pytest.param(
+            "CREATE TABLE my_dataset.my_table AS SELECT 123 AS num",
+            "CREATE_AS_SELECT",
+            id="create-as-select",
+        ),
+        pytest.param(
+            "DROP TABLE my_dataset.my_table",
+            "DROP_TABLE",
+            id="drop-table",
+        ),
+        pytest.param(
+            "CREATE MODEL my_dataset.my_model (model_type='linear_reg',"
+            " input_label_cols=['label_col']) AS SELECT * FROM"
+            " my_dataset.my_table",
+            "CREATE_MODEL",
+            id="create-model",
+        ),
+        pytest.param(
+            "DROP MODEL my_dataset.my_model",
+            "DROP_MODEL",
+            id="drop-model",
+        ),
+    ],
+)
+def test_execute_sql_non_select_stmt_write_protected_persistent_target(
+    query, statement_type
+):
+  """Test execute_sql tool for non-SELECT query when writes are protected.
+
+  This is a special case where the destination table is a persistent/permanent
+  one and protected write is enabled. In this case the operation should fail.
+ """ + project = "my_project" + query_result = [] + credentials = mock.create_autospec(Credentials, instance=True) + tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + query_job.destination.dataset_id = "my_dataset" + bq_client.query.return_value = query_job + + # Simulate the result of query_and_wait API + bq_client.query_and_wait.return_value = query_result + + # Test the tool + result = execute_sql(project, query, credentials, tool_config, tool_context) + assert result == { + "status": "ERROR", + "error_details": ( + "Protected write mode only supports SELECT statements, or write" + " operations in the anonymous dataset of a BigQuery session." + ), + } + + @pytest.mark.parametrize( ("write_mode",), [ pytest.param(WriteMode.BLOCKED, id="blocked"), + pytest.param(WriteMode.PROTECTED, id="protected"), pytest.param(WriteMode.ALLOWED, id="allowed"), ], ) @@ -361,6 +805,11 @@ def test_execute_sql_no_default_auth( query_result = [{"num": 123}] credentials = mock.create_autospec(Credentials, instance=True) tool_config = BigQueryToolConfig(write_mode=write_mode) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) # Simulate the behavior of default auth - on purpose throw exception when # the default auth is called @@ -377,6 +826,6 @@ def test_execute_sql_no_default_auth( mock_query_and_wait.return_value = query_result # Test the tool worked without invoking default auth - result = execute_sql(project, query, credentials, tool_config) + result = execute_sql(project, query, credentials, tool_config, tool_context) assert result == {"status": "SUCCESS", "rows": query_result} mock_default_auth.assert_not_called()