feat: Support protected write in BigQuery execute_sql tool
This change adds a new enum value that the agent builder can pass in the `BigQueryToolConfig` to allow limited writes through the `execute_sql` tool.

PiperOrigin-RevId: 776661744
committed by Copybara-Service
parent e2748b3ed5
commit dc43d518c9
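For orientation, a minimal sketch (not part of this commit) of how an agent builder opts into the new mode; the import path is inferred from the `.config` module touched in the diff below and may differ in the released package:

# Sketch only; the import path is an assumption based on this diff.
from google.adk.tools.bigquery.config import BigQueryToolConfig, WriteMode

# PROTECTED permits writes only inside the anonymous dataset of a BigQuery
# session, so temporary artifacts can be created while permanent tables stay
# read-only during agent interaction.
tool_config = BigQueryToolConfig(write_mode=WriteMode.PROTECTED)

The resulting config is then passed to the BigQuery tools the same way the sample agent below passes its `WriteMode.ALLOWED` config.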
@@ -26,7 +26,11 @@ import google.auth
CREDENTIALS_TYPE = AuthCredentialTypes.OAUTH2

# Define BigQuery tool config
# Define BigQuery tool config with write mode set to allowed. Note that this is
# only to demonstrate the full capability of the BigQuery tools. In production
# you may want to change to BLOCKED (default write mode, effectively makes the
# tool read-only) or PROTECTED (only allows writes in the anonymous dataset of a
# BigQuery session) write mode.
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)

if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2:
@@ -26,10 +26,20 @@ class WriteMode(Enum):

  BLOCKED = 'blocked'
  """No write operations are allowed.

  This mode implies that only read (i.e. SELECT query) operations are allowed.
  """

  PROTECTED = 'protected'
  """Only protected write operations are allowed in a BigQuery session.

  In this mode write operations in the anonymous dataset of a BigQuery session
  are allowed. For example, a temporary table can be created, manipulated and
  deleted in the anonymous dataset during Agent interaction, while protecting
  permanent tables from being modified or deleted. To learn more about BigQuery
  sessions, see https://cloud.google.com/bigquery/docs/sessions-intro.
  """

  ALLOWED = 'allowed'
  """All write operations are allowed."""
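The PROTECTED docstring above relies on BigQuery sessions and their anonymous dataset. As an illustrative sketch (not from this commit), the underlying mechanics with the plain google-cloud-bigquery client look roughly like this; the project id and SQL are placeholders:

# Sketch of the session mechanics PROTECTED builds on.
from google.cloud import bigquery

client = bigquery.Client(project="my_project")  # placeholder project

# A dry run with create_session=True opens a session without running anything.
session_job = client.query(
    "SELECT 1",
    job_config=bigquery.QueryJobConfig(dry_run=True, create_session=True),
)
session_id = session_job.session_info.session_id

# Statements carrying the session_id connection property run inside the
# session; TEMP tables land in its anonymous dataset and expire with it.
client.query_and_wait(
    "CREATE TEMP TABLE my_table (island STRING, population INT64)",
    job_config=bigquery.QueryJobConfig(
        connection_properties=[
            bigquery.ConnectionProperty("session_id", session_id)
        ]
    ),
)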
@@ -20,10 +20,12 @@ from google.auth.credentials import Credentials
from google.cloud import bigquery

from . import client
from ..tool_context import ToolContext
from .config import BigQueryToolConfig
from .config import WriteMode

MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
BIGQUERY_SESSION_INFO_KEY = "bigquery_session_info"


def execute_sql(
@@ -31,14 +33,17 @@ def execute_sql(
    query: str,
    credentials: Credentials,
    config: BigQueryToolConfig,
    tool_context: ToolContext,
) -> dict:
  """Run a BigQuery SQL query in the project and return the result.
  """Run a BigQuery or BigQuery ML SQL query in the project and return the result.

  Args:
    project_id (str): The GCP project id in which the query should be
      executed.
    query (str): The BigQuery SQL query to be executed.
    credentials (Credentials): The credentials to use for the request.
    config (BigQueryToolConfig): The configuration for the tool.
    tool_context (ToolContext): The context for the tool.

  Returns:
    dict: Dictionary representing the result of the query.
@@ -49,11 +54,11 @@ def execute_sql(
  Examples:
    Fetch data or insights from a table:

    >>> execute_sql("bigframes-dev",
    >>> execute_sql("my_project",
    ... "SELECT island, COUNT(*) AS population "
    ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
    {
      "status": "ERROR",
      "status": "SUCCESS",
      "rows": [
        {
          "island": "Dream",
@@ -72,23 +77,87 @@ def execute_sql(
  """

  try:
    # Get BigQuery client
    bq_client = client.get_bigquery_client(
        project=project_id, credentials=credentials
    )

    # BigQuery connection properties where applicable
    bq_connection_properties = None

    if not config or config.write_mode == WriteMode.BLOCKED:
      query_job = bq_client.query(
      dry_run_query_job = bq_client.query(
          query,
          project=project_id,
          job_config=bigquery.QueryJobConfig(dry_run=True),
      )
      if query_job.statement_type != "SELECT":
      if dry_run_query_job.statement_type != "SELECT":
        return {
            "status": "ERROR",
            "error_details": "Read-only mode only supports SELECT statements.",
        }
    elif config.write_mode == WriteMode.PROTECTED:
      # In protected write mode, write operation only to a temporary artifact is
      # allowed. This artifact must have been created in a BigQuery session. In
      # such a scenario the session info (session id and the anonymous dataset
      # containing the artifact) is persisted in the tool context.
      bq_session_info = tool_context.state.get(BIGQUERY_SESSION_INFO_KEY, None)
      if bq_session_info:
        bq_session_id, bq_session_dataset_id = bq_session_info
      else:
        session_creator_job = bq_client.query(
            "SELECT 1",
            project=project_id,
            job_config=bigquery.QueryJobConfig(
                dry_run=True, create_session=True
            ),
        )
        bq_session_id = session_creator_job.session_info.session_id
        bq_session_dataset_id = session_creator_job.destination.dataset_id

        # Remember the BigQuery session info for subsequent queries
        tool_context.state[BIGQUERY_SESSION_INFO_KEY] = (
            bq_session_id,
            bq_session_dataset_id,
        )

      # Session connection property will be set in the query execution
      bq_connection_properties = [
          bigquery.ConnectionProperty("session_id", bq_session_id)
      ]

      # Check the query type w.r.t. the BigQuery session
      dry_run_query_job = bq_client.query(
          query,
          project=project_id,
          job_config=bigquery.QueryJobConfig(
              dry_run=True,
              connection_properties=bq_connection_properties,
          ),
      )
      if (
          dry_run_query_job.statement_type != "SELECT"
          and dry_run_query_job.destination.dataset_id != bq_session_dataset_id
      ):
        return {
            "status": "ERROR",
            "error_details": (
                "Protected write mode only supports SELECT statements, or write"
                " operations in the anonymous dataset of a BigQuery session."
            ),
        }

    # Finally execute the query and fetch the result
    job_config = (
        bigquery.QueryJobConfig(connection_properties=bq_connection_properties)
        if bq_connection_properties
        else None
    )
    row_iterator = bq_client.query_and_wait(
        query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
        query,
        job_config=job_config,
        project=project_id,
        max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS,
    )
    rows = [{key: val for key, val in row.items()} for row in row_iterator]
    result = {"status": "SUCCESS", "rows": rows}
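Both write-mode gates above hinge on a dry run, which reports what a statement would do without executing it. A standalone sketch (not from this commit) with placeholder project, dataset and table names:

from google.cloud import bigquery

client = bigquery.Client(project="my_project")  # placeholder project
dry_run_job = client.query(
    "CREATE TABLE my_dataset.my_table AS SELECT 1 AS x",
    job_config=bigquery.QueryJobConfig(dry_run=True),
)
# statement_type is e.g. "SELECT" or "CREATE_TABLE_AS_SELECT"; anything other
# than "SELECT" is rejected in BLOCKED mode, and in PROTECTED mode the
# destination dataset must match the session's anonymous dataset.
print(dry_run_job.statement_type)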
@@ -106,9 +175,29 @@ def execute_sql(


_execute_sql_write_examples = """
  Create a table with schema prescribed:

    >>> execute_sql("my_project",
    ... "CREATE TABLE my_project.my_dataset.my_table "
    ... "(island STRING, population INT64)")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Insert data into an existing table:

    >>> execute_sql("my_project",
    ... "INSERT INTO my_project.my_dataset.my_table (island, population) "
    ... "VALUES ('Dream', 124), ('Biscoe', 168)")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Create a table from the result of a query:

    >>> execute_sql("bigframes-dev",
    >>> execute_sql("my_project",
    ... "CREATE TABLE my_project.my_dataset.my_table AS "
    ... "SELECT island, COUNT(*) AS population "
    ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
@@ -119,7 +208,7 @@ _execute_sql_write_examples = """

  Delete a table:

    >>> execute_sql("bigframes-dev",
    >>> execute_sql("my_project",
    ... "DROP TABLE my_project.my_dataset.my_table")
    {
      "status": "SUCCESS",
@@ -128,7 +217,7 @@ _execute_sql_write_examples = """

  Copy a table to another table:

    >>> execute_sql("bigframes-dev",
    >>> execute_sql("my_project",
    ... "CREATE TABLE my_project.my_dataset.my_table_clone "
    ... "CLONE my_project.my_dataset.my_table")
    {
@@ -139,7 +228,7 @@ _execute_sql_write_examples = """
  Create a snapshot (a lightweight, read-optimized copy) of an existing
  table:

    >>> execute_sql("bigframes-dev",
    >>> execute_sql("my_project",
    ... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
    ... "CLONE my_project.my_dataset.my_table")
    {
@@ -147,12 +236,214 @@ _execute_sql_write_examples = """
      "rows": []
    }

  Create a BigQuery ML linear regression model:

    >>> execute_sql("my_project",
    ... "CREATE MODEL `my_dataset.my_model` "
    ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS "
    ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
    ... "WHERE body_mass_g IS NOT NULL")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Evaluate BigQuery ML model:

    >>> execute_sql("my_project",
    ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`)")
    {
      "status": "SUCCESS",
      "rows": [{'mean_absolute_error': 227.01223667447218,
                'mean_squared_error': 81838.15989216768,
                'mean_squared_log_error': 0.0050704473735013,
                'median_absolute_error': 173.08081641661738,
                'r2_score': 0.8723772534253441,
                'explained_variance': 0.8723772534253442}]
    }

  Evaluate BigQuery ML model on custom data:

    >>> execute_sql("my_project",
    ... "SELECT * FROM ML.EVALUATE(MODEL `my_dataset.my_model`, "
    ... "(SELECT * FROM `my_dataset.my_table`))")
    {
      "status": "SUCCESS",
      "rows": [{'mean_absolute_error': 227.01223667447218,
                'mean_squared_error': 81838.15989216768,
                'mean_squared_log_error': 0.0050704473735013,
                'median_absolute_error': 173.08081641661738,
                'r2_score': 0.8723772534253441,
                'explained_variance': 0.8723772534253442}]
    }

  Predict using BigQuery ML model:

    >>> execute_sql("my_project",
    ... "SELECT * FROM ML.PREDICT(MODEL `my_dataset.my_model`, "
    ... "(SELECT * FROM `my_dataset.my_table`))")
    {
      "status": "SUCCESS",
      "rows": [
        {
          "predicted_body_mass_g": "3380.9271650847013",
          ...
        }, {
          "predicted_body_mass_g": "3873.6072435386004",
          ...
        },
        ...
      ]
    }

  Delete a BigQuery ML model:

    >>> execute_sql("my_project", "DROP MODEL `my_dataset.my_model`")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Notes:
    - If a destination table already exists, there are a few ways to overwrite
      it:
        - Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
        - First run "DROP TABLE", followed by "CREATE TABLE".
    - To insert data into a table, use the "INSERT INTO" statement.
    - If a model already exists, there are a few ways to overwrite it:
        - Use "CREATE OR REPLACE MODEL" instead of "CREATE MODEL".
        - First run "DROP MODEL", followed by "CREATE MODEL".
"""
_execute_sql_protecetd_write_examples = """
  Create a temporary table with schema prescribed:

    >>> execute_sql("my_project",
    ... "CREATE TEMP TABLE my_table (island STRING, population INT64)")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Insert data into an existing temporary table:

    >>> execute_sql("my_project",
    ... "INSERT INTO my_table (island, population) "
    ... "VALUES ('Dream', 124), ('Biscoe', 168)")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Create a temporary table from the result of a query:

    >>> execute_sql("my_project",
    ... "CREATE TEMP TABLE my_table AS "
    ... "SELECT island, COUNT(*) AS population "
    ... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Delete a temporary table:

    >>> execute_sql("my_project", "DROP TABLE my_table")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Copy a temporary table to another temporary table:

    >>> execute_sql("my_project",
    ... "CREATE TEMP TABLE my_table_clone CLONE my_table")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Create a temporary BigQuery ML linear regression model:

    >>> execute_sql("my_project",
    ... "CREATE TEMP MODEL my_model "
    ... "OPTIONS (model_type='linear_reg', input_label_cols=['body_mass_g']) AS "
    ... "SELECT * FROM `bigquery-public-data.ml_datasets.penguins` "
    ... "WHERE body_mass_g IS NOT NULL")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Evaluate BigQuery ML model:

    >>> execute_sql("my_project", "SELECT * FROM ML.EVALUATE(MODEL my_model)")
    {
      "status": "SUCCESS",
      "rows": [{'mean_absolute_error': 227.01223667447218,
                'mean_squared_error': 81838.15989216768,
                'mean_squared_log_error': 0.0050704473735013,
                'median_absolute_error': 173.08081641661738,
                'r2_score': 0.8723772534253441,
                'explained_variance': 0.8723772534253442}]
    }

  Evaluate BigQuery ML model on custom data:

    >>> execute_sql("my_project",
    ... "SELECT * FROM ML.EVALUATE(MODEL my_model, "
    ... "(SELECT * FROM `my_dataset.my_table`))")
    {
      "status": "SUCCESS",
      "rows": [{'mean_absolute_error': 227.01223667447218,
                'mean_squared_error': 81838.15989216768,
                'mean_squared_log_error': 0.0050704473735013,
                'median_absolute_error': 173.08081641661738,
                'r2_score': 0.8723772534253441,
                'explained_variance': 0.8723772534253442}]
    }

  Predict using BigQuery ML model:

    >>> execute_sql("my_project",
    ... "SELECT * FROM ML.PREDICT(MODEL my_model, "
    ... "(SELECT * FROM `my_dataset.my_table`))")
    {
      "status": "SUCCESS",
      "rows": [
        {
          "predicted_body_mass_g": "3380.9271650847013",
          ...
        }, {
          "predicted_body_mass_g": "3873.6072435386004",
          ...
        },
        ...
      ]
    }

  Delete a BigQuery ML model:

    >>> execute_sql("my_project", "DROP MODEL my_model")
    {
      "status": "SUCCESS",
      "rows": []
    }

  Notes:
    - If a destination table already exists, there are a few ways to overwrite
      it:
        - Use "CREATE OR REPLACE TEMP TABLE" instead of "CREATE TEMP TABLE".
        - First run "DROP TABLE", followed by "CREATE TEMP TABLE".
    - Only temporary tables can be created, inserted into or deleted. Please
      do not try creating a permanent table (non-TEMP table), inserting into or
      deleting one.
    - If a destination model already exists, there are a few ways to overwrite
      it:
        - Use "CREATE OR REPLACE TEMP MODEL" instead of "CREATE TEMP MODEL".
        - First run "DROP MODEL", followed by "CREATE TEMP MODEL".
    - Only temporary models can be created or deleted. Please do not try
      creating a permanent model (non-TEMP model) or deleting one.
"""
@@ -189,6 +480,9 @@ def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
  functools.update_wrapper(execute_sql_wrapper, execute_sql)

  # Now, set the new docstring
  execute_sql_wrapper.__doc__ += _execute_sql_write_examples
  if config.write_mode == WriteMode.PROTECTED:
    execute_sql_wrapper.__doc__ += _execute_sql_protecetd_write_examples
  else:
    execute_sql_wrapper.__doc__ += _execute_sql_write_examples

  return execute_sql_wrapper
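The wrapper above chooses which example block to append to the tool's docstring based on the configured write mode. A condensed sketch of the pattern (the helper name here is illustrative, not the module's):

import functools


def _with_examples(base_fn, examples: str):
  """Wrap base_fn, keep its metadata, and extend its docstring with examples."""

  def wrapper(*args, **kwargs):
    return base_fn(*args, **kwargs)

  functools.update_wrapper(wrapper, base_fn)  # copies __name__, __doc__, ...
  wrapper.__doc__ += examples  # append mode-specific examples to the base docstring
  return wrapper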
File diff suppressed because it is too large