feat: Support protected write in BigQuery execute_sql tool

This change adds a new enum value which the agent builder can pass in the `BigQueryToolConfig` to allow limited writes to the `execute_sql` tool.

PiperOrigin-RevId: 776661744
This commit is contained in:
Google Team Member
2025-06-27 11:42:40 -07:00
committed by Copybara-Service
parent e2748b3ed5
commit dc43d518c9
4 changed files with 786 additions and 29 deletions
+5 -1
View File
@@ -26,7 +26,11 @@ import google.auth
CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2
# Define BigQuery tool config
# Define BigQuery tool config with write mode set to allowed. Note that this is
# only to demonstrate the full capability of the BigQuery tools. In production
# you may want to change to BLOCKED (default write mode, effectively makes the
# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a
# BigQuery session) write mode.
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
+11 -1
View File
@@ -26,10 +26,20 @@ class WriteMode(Enum):
BLOCKED = 'blocked'
"""No write operations are allowed.
This mode implies that only read (i.e. SELECT query) operations are allowed.
"""
PROTECTED = 'protected'
"""Only protected write operations are allowed in a BigQuery session.
In this mode write operations in the anonymous dataset of a BigQuery session
are allowed. For example, a temporaray table can be created, manipulated and
deleted in the anonymous dataset during Agent interaction, while protecting
permanent tables from being modified or deleted. To learn more about BigQuery
sessions, see https://cloud.google.com/bigquery/docs/sessions-intro.
"""
ALLOWED = 'allowed'
"""All write operations are allowed."""
+306 -12
View File
@@ -20,10 +20,12 @@ from google.auth.credentials import Credentials
from google.cloud import bigquery
from . import client
from ..tool_context import ToolContext
from .config import BigQueryToolConfig
from .config import WriteMode
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info"
def execute_sql(
@@ -31,14 +33,17 @@ def execute_sql(
query: str,
credentials: Credentials,
config: BigQueryToolConfig,
tool_context: ToolContext,
) -> dict:
"""Run a BigQuery SQL query in the project and return the result.
"""Run a BigQuery or BigQuery ML SQL query in the project and return the result.
Args:
project_id (str): The GCP project id in which the query should be
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
config (BigQueryToolConfig): The configuration for the tool.
tool_context (ToolContext): The context for the tool.
Returns:
dict: Dictionary representing the result of the query.
@@ -49,11 +54,11 @@ def execute_sql(
Examples:
Fetch data or insights from a table:
>>> execute_sql("bigframes-dev",
>>> execute_sql("my_project",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "ERROR",
"status": "SUCCESS",
"rows": [
{
"island": "Dream",
@@ -72,23 +77,87 @@ def execute_sql(
"""
try:
# Get BigQuery client
bq_client = client.get_bigquery_client(
project=project_id, credentials=credentials
)
# BigQuery connection properties where applicable
bq_connection_properties = None
if not config or config.write_mode == WriteMode.BLOCKED:
query_job = bq_client.query(
dry_run_query_job = bq_client.query(
query,
project=project_id,
job_config=bigquery.QueryJobConfig(dry_run=True),
)
if query_job.statement_type != "SELECT":
if dry_run_query_job.statement_type != "SELECT":
return {
"status": "ERROR",
"error_details": "Read-only mode only supports SELECT statements.",
}
elif config.write_mode == WriteMode.PROTECTED:
# In protected write mode, write operation only to a temporary artifact is
# allowed. This artifact must have been created in a BigQuery session. In
# such a scenario the session info (session id and the anonymous dataset
# containing the artifact) is persisted in the tool context.
bq_session_info = tool_context.state.get(BIGQUERY_SESSION_INFO_KEY, None)
if bq_session_info:
bq_session_id, bq_session_dataset_id = bq_session_info
else:
session_creator_job = bq_client.query(
"SELECT 1",
project=project_id,
job_config=bigquery.QueryJobConfig(
dry_run=True, create_session=True
),
)
bq_session_id = session_creator_job.session_info.session_id
bq_session_dataset_id = session_creator_job.destination.dataset_id
# Remember the BigQuery session info for subsequent queries
tool_context.state[BIGQUERY_SESSION_INFO_KEY] = (
bq_session_id,
bq_session_dataset_id,
)
# Session connection property will be set in the query execution
bq_connection_properties = [
bigquery.ConnectionProperty("session_id", bq_session_id)
]
# Check the query type w.r.t. the BigQuery session
dry_run_query_job = bq_client.query(
query,
project=project_id,
job_config=bigquery.QueryJobConfig(
dry_run=True,
connection_properties=bq_connection_properties,
),
)
if (
dry_run_query_job.statement_type != "SELECT"
and dry_run_query_job.destination.dataset_id != bq_session_dataset_id
):
return {
"status": "ERROR",
"error_details": (
"Protected write mode only supports SELECT statements, or write"
" operations in the anonymous dataset of a BigQuery session."
),
}
# Finally execute the query and fetch the result
job_config = (
bigquery.QueryJobConfig(connection_properties=bq_connection_properties)
if bq_connection_properties
else None
)
row_iterator = bq_client.query_and_wait(
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
query,
job_config=job_config,
project=project_id,
max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS,
)
rows = [{key: val for key, val in row.items()} for row in row_iterator]
result = {"status": "SUCCESS", "rows": rows}
@@ -106,9 +175,29 @@ def execute_sql(
_execute_sql_write_examples = """
Create a table with schema prescribed:
>>> execute_sql("my_project",
... "CREATE TABLE my_project.my_dataset.my_table "
... "(island STRING, population INT64)")
{
"status": "SUCCESS",
"rows": []
}
Insert data into an existing table:
>>> execute_sql("my_project",
... "INSERT INTO my_project.my_dataset.my_table (island, population) "
... "VALUES ('Dream', 124), ('Biscoe', 168)")
{
"status": "SUCCESS",
"rows": []
}
Create a table from the result of a query:
>>> execute_sql("bigframes-dev",
>>> execute_sql("my_project",
... "CREATE TABLE my_project.my_dataset.my_table AS "
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
@@ -119,7 +208,7 @@ _execute_sql_write_examples = """
Delete a table:
>>> execute_sql("bigframes-dev",
>>> execute_sql("my_project",
... "DROP TABLE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
@@ -128,7 +217,7 @@ _execute_sql_write_examples = """
Copy a table to another table:
>>> execute_sql("bigframes-dev",
>>> execute_sql("my_project",
... "CREATE TABLE my_project.my_dataset.my_table_clone "
... "CLONE my_project.my_dataset.my_table")
{
@@ -139,7 +228,7 @@ _execute_sql_write_examples = """
Create a snapshot (a lightweight, read-optimized copy) of en existing
table:
>>> execute_sql("bigframes-dev",
>>> execute_sql("my_project",
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
... "CLONE my_project.my_dataset.my_table")
{
@@ -147,12 +236,214 @@ _execute_sql_write_examples = """
"rows": []
}
Create a BigQuery ML linear regression model:
>>> execute_sql("my_project",
... "CREATE MODEL `my_dataset.my_model` "
... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS "
... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
... "WHERE body_mass_g IS NOT NULL")
{
"status": "SUCCESS",
"rows": []
}
Evaluate BigQuery ML model:
>>> execute_sql("my_project",
... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)")
{
"status": "SUCCESS",
"rows": [{'mean_absolute_error': 227.01223667447218,
'mean_squared_error': 81838.15989216768,
'mean_squared_log_error': 0.0050704473735013,
'median_absolute_error': 173.08081641661738,
'r2_score': 0.8723772534253441,
'explained_variance': 0.8723772534253442}]
}
Evaluate BigQuery ML model on custom data:
>>> execute_sql("my_project",
... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, "
... "(SELECT * FROM `my_dataset.my_table`))")
{
"status": "SUCCESS",
"rows": [{'mean_absolute_error': 227.01223667447218,
'mean_squared_error': 81838.15989216768,
'mean_squared_log_error': 0.0050704473735013,
'median_absolute_error': 173.08081641661738,
'r2_score': 0.8723772534253441,
'explained_variance': 0.8723772534253442}]
}
Predict using BigQuery ML model:
>>> execute_sql("my_project",
... "SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, "
... "(SELECT * FROM `my_dataset.my_table`))")
{
"status": "SUCCESS",
"rows": [
{
"predicted_body_mass_g": "3380.9271650847013",
...
}, {
"predicted_body_mass_g": "3873.6072435386004",
...
},
...
]
}
Delete a BigQuery ML model:
>>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`")
{
"status": "SUCCESS",
"rows": []
}
Notes:
- If a destination table already exists, there are a few ways to overwrite
it:
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
- First run "DROP TABLE", followed by "CREATE TABLE".
- To insert data into a table, use "INSERT INTO" statement.
- If a model already exists, there are a few ways to overwrite it:
- Use "CREATE OR REPLACE MODEL" instead of "CREATE MODEL".
- First run "DROP MODEL", followed by "CREATE MODEL".
"""
_execute_sql_protecetd_write_examples = """
Create a temporary table with schema prescribed:
>>> execute_sql("my_project",
... "CREATE TEMP TABLE my_table (island STRING, population INT64)")
{
"status": "SUCCESS",
"rows": []
}
Insert data into an existing temporary table:
>>> execute_sql("my_project",
... "INSERT INTO my_table (island, population) "
... "VALUES ('Dream', 124), ('Biscoe', 168)")
{
"status": "SUCCESS",
"rows": []
}
Create a temporary table from the result of a query:
>>> execute_sql("my_project",
... "CREATE TEMP TABLE my_table AS "
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "SUCCESS",
"rows": []
}
Delete a temporary table:
>>> execute_sql("my_project", "DROP TABLE my_table")
{
"status": "SUCCESS",
"rows": []
}
Copy a temporary table to another temporary table:
>>> execute_sql("my_project",
... "CREATE TEMP TABLE my_table_clone CLONE my_table")
{
"status": "SUCCESS",
"rows": []
}
Create a temporary BigQuery ML linear regression model:
>>> execute_sql("my_project",
... "CREATE TEMP MODEL my_model "
... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS"
... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
... "WHERE body_mass_g IS NOT NULL")
{
"status": "SUCCESS",
"rows": []
}
Evaluate BigQuery ML model:
>>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)")
{
"status": "SUCCESS",
"rows": [{'mean_absolute_error': 227.01223667447218,
'mean_squared_error': 81838.15989216768,
'mean_squared_log_error': 0.0050704473735013,
'median_absolute_error': 173.08081641661738,
'r2_score': 0.8723772534253441,
'explained_variance': 0.8723772534253442}]
}
Evaluate BigQuery ML model on custom data:
>>> execute_sql("my_project",
... "SELECT * FROM ML.EVALUATE(MODEL my_model, "
... "(SELECT * FROM `my_dataset.my_table`))")
{
"status": "SUCCESS",
"rows": [{'mean_absolute_error': 227.01223667447218,
'mean_squared_error': 81838.15989216768,
'mean_squared_log_error': 0.0050704473735013,
'median_absolute_error': 173.08081641661738,
'r2_score': 0.8723772534253441,
'explained_variance': 0.8723772534253442}]
}
Predict using BigQuery ML model:
>>> execute_sql("my_project",
... "SELECT * FROM ML.PREDICT(MODEL my_model, "
... "(SELECT * FROM `my_dataset.my_table`))")
{
"status": "SUCCESS",
"rows": [
{
"predicted_body_mass_g": "3380.9271650847013",
...
}, {
"predicted_body_mass_g": "3873.6072435386004",
...
},
...
]
}
Delete a BigQuery ML model:
>>> execute_sql("my_project", "DROP MODEL my_model")
{
"status": "SUCCESS",
"rows": []
}
Notes:
- If a destination table already exists, there are a few ways to overwrite
it:
- Use "CREATE OR REPLACE TEMP TABLE" instead of "CREATE TEMP TABLE".
- First run "DROP TABLE", followed by "CREATE TEMP TABLE".
- Only temporary tables can be created, inserted into or deleted. Please
do not try creating a permanent table (non-TEMP table), inserting into or
deleting one.
- If a destination model already exists, there are a few ways to overwrite
it:
- Use "CREATE OR REPLACE TEMP MODEL" instead of "CREATE TEMP MODEL".
- First run "DROP MODEL", followed by "CREATE TEMP MODEL".
- Only temporary models can be created or deleted. Please do not try
creating a permanent model (non-TEMP model) or deleting one.
"""
@@ -189,6 +480,9 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
functools.update_wrapper(execute_sql_wrapper, execute_sql)
# Now, set the new docstring
execute_sql_wrapper.__doc__ += _execute_sql_write_examples
if config.write_mode == WriteMode.PROTECTED:
execute_sql_wrapper.__doc__ += _execute_sql_protecetd_write_examples
else:
execute_sql_wrapper.__doc__ += _execute_sql_write_examples
return execute_sql_wrapper
File diff suppressed because it is too large Load Diff