feat: Introduce write protected mode to BigQuery tools

This allows protecting against any write operations (e.g. updating or deleting a table), which is useful for agents that must be used in read-only mode even though the user may have write permissions.

PiperOrigin-RevId: 769803741
This commit is contained in:
Google Team Member
2025-06-10 14:36:42 -07:00
committed by Copybara-Service
parent 77f44a4e45
commit 6c999caa41
10 changed files with 490 additions and 51 deletions
+7 -1
View File
@@ -17,11 +17,15 @@ import os
from google.adk.agents import llm_agent
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
import google.auth
RUN_WITH_ADC = False
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
if RUN_WITH_ADC:
# Initialize the tools to use the application default credentials.
application_default_credentials, _ = google.auth.default()
@@ -37,7 +41,9 @@ else:
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
)
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
bigquery_toolset = BigQueryToolset(
credentials_config=credentials_config, bigquery_tool_config=tool_config
)
# The variable name `root_agent` determines what your root agent is for the
# debug CLI
+19 -9
View File
@@ -25,6 +25,7 @@ from ..function_tool import FunctionTool
from ..tool_context import ToolContext
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_credentials import BigQueryCredentialsManager
from .config import BigQueryToolConfig
class BigQueryTool(FunctionTool):
@@ -41,21 +42,27 @@ class BigQueryTool(FunctionTool):
def __init__(
self,
func: Callable[..., Any],
credentials: Optional[BigQueryCredentialsConfig] = None,
*,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
):
"""Initialize the Google API tool.
Args:
func: callable that implements the tool's logic; it can accept one
'credentials' parameter
credentials: credentials used to call Google API. If None, then we don't
hanlde the auth logic
credentials_config: credentials config used to call Google API. If None,
then we don't handle the auth logic
"""
super().__init__(func=func)
self._ignore_params.append("credentials")
self.credentials_manager = (
BigQueryCredentialsManager(credentials) if credentials else None
self._ignore_params.append("config")
self._credentials_manager = (
BigQueryCredentialsManager(credentials_config)
if credentials_config
else None
)
self._tool_config = bigquery_tool_config
@override
async def run_async(
@@ -69,12 +76,12 @@ class BigQueryTool(FunctionTool):
try:
# Get valid credentials
credentials = (
await self.credentials_manager.get_valid_credentials(tool_context)
if self.credentials_manager
await self._credentials_manager.get_valid_credentials(tool_context)
if self._credentials_manager
else None
)
if credentials is None and self.credentials_manager:
if credentials is None and self._credentials_manager:
# OAuth flow in progress
return (
"User authorization is required to access Google services for"
@@ -84,7 +91,7 @@ class BigQueryTool(FunctionTool):
# Execute the tool's specific logic with valid credentials
return await self._run_async_with_credential(
credentials, args, tool_context
credentials, self._tool_config, args, tool_context
)
except Exception as ex:
@@ -96,6 +103,7 @@ class BigQueryTool(FunctionTool):
async def _run_async_with_credential(
self,
credentials: Credentials,
tool_config: BigQueryToolConfig,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
@@ -113,4 +121,6 @@ class BigQueryTool(FunctionTool):
signature = inspect.signature(self.func)
if "credentials" in signature.parameters:
args_to_call["credentials"] = credentials
if "config" in signature.parameters:
args_to_call["config"] = tool_config
return await super().run_async(args=args_to_call, tool_context=tool_context)
@@ -28,6 +28,7 @@ from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from .bigquery_credentials import BigQueryCredentialsConfig
from .bigquery_tool import BigQueryTool
from .config import BigQueryToolConfig
class BigQueryToolset(BaseToolset):
@@ -38,9 +39,11 @@ class BigQueryToolset(BaseToolset):
*,
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
credentials_config: Optional[BigQueryCredentialsConfig] = None,
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
):
self._credentials_config = credentials_config
self.tool_filter = tool_filter
self._credentials_config = credentials_config
self._tool_config = bigquery_tool_config
def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
@@ -64,14 +67,15 @@ class BigQueryToolset(BaseToolset):
all_tools = [
BigQueryTool(
func=func,
credentials=self._credentials_config,
credentials_config=self._credentials_config,
bigquery_tool_config=self._tool_config,
)
for func in [
metadata_tool.get_dataset_info,
metadata_tool.get_table_info,
metadata_tool.list_dataset_ids,
metadata_tool.list_table_ids,
query_tool.execute_sql,
query_tool.get_execute_sql(self._tool_config),
]
]
+46
View File
@@ -0,0 +1,46 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel
from ...utils.feature_decorator import experimental
class WriteMode(Enum):
  """Levels of write access permitted when running BigQuery tool operations."""

  BLOCKED = 'blocked'
  """Write operations are disallowed entirely.

  In this mode only read operations (i.e. SELECT queries) may be executed.
  """

  ALLOWED = 'allowed'
  """Every write operation is permitted."""
@experimental('Config defaults may have breaking change in the future.')
class BigQueryToolConfig(BaseModel):
  """Configuration for BigQuery tools."""

  # BLOCKED restricts execute_sql to SELECT statements (enforced via a
  # dry-run statement-type check before the query is actually run).
  write_mode: WriteMode = WriteMode.BLOCKED
  """Write mode for BigQuery tools.

  By default, the tool will allow only read operations. This behaviour may
  change in future versions.
  """
@@ -15,7 +15,7 @@
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
from . import client
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
+138 -22
View File
@@ -12,14 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import types
from typing import Callable
from google.cloud import bigquery
from google.oauth2.credentials import Credentials
from ...tools.bigquery import client
from . import client
from .config import BigQueryToolConfig
from .config import WriteMode
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
def execute_sql(
project_id: str,
query: str,
credentials: Credentials,
config: BigQueryToolConfig,
) -> dict:
"""Run a BigQuery SQL query in the project and return the result.
Args:
@@ -35,34 +47,49 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
query not returned in the result.
Examples:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}
Fetch data or insights from a table:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "ERROR",
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}
"""
try:
bq_client = client.get_bigquery_client(credentials=credentials)
if not config or config.write_mode == WriteMode.BLOCKED:
query_job = bq_client.query(
query,
project=project_id,
job_config=bigquery.QueryJobConfig(dry_run=True),
)
if query_job.statement_type != "SELECT":
return {
"status": "ERROR",
"error_details": "Read-only mode only supports SELECT statements.",
}
row_iterator = bq_client.query_and_wait(
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
)
rows = [{key: val for key, val in row.items()} for row in row_iterator]
result = {"rows": rows}
result = {"status": "SUCCESS", "rows": rows}
if (
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
@@ -74,3 +101,92 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
"status": "ERROR",
"error_details": str(ex),
}
# Write-operation docstring examples appended to execute_sql.__doc__ by
# get_execute_sql() when the tool config allows writes; the read-only variant
# of the tool does not include them.
_execute_sql_write_examples = """
Create a table from the result of a query:
>>> execute_sql("bigframes-dev",
... "CREATE TABLE my_project.my_dataset.my_table AS "
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "SUCCESS",
"rows": []
}
Delete a table:
>>> execute_sql("bigframes-dev",
... "DROP TABLE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Copy a table to another table:
>>> execute_sql("bigframes-dev",
... "CREATE TABLE my_project.my_dataset.my_table_clone "
... "CLONE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Create a snapshot (a lightweight, read-optimized copy) of en existing
table:
>>> execute_sql("bigframes-dev",
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
... "CLONE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Notes:
- If a destination table already exists, there are a few ways to overwrite
it:
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
- First run "DROP TABLE", followed by "CREATE TABLE".
- To insert data into a table, use "INSERT INTO" statement.
"""
def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
  """Return a variant of the execute_sql tool tailored to the tool config.

  Args:
    config: BigQuery tool configuration indicating the behavior of the
      execute_sql tool.

  Returns:
    callable[..., dict]: A version of the execute_sql tool respecting the tool
      config.
  """
  if not config or config.write_mode == WriteMode.BLOCKED:
    return execute_sql

  # Clone the underlying function object (same code, globals, name, defaults
  # and closure) so its docstring can be customized without mutating the
  # shared, module-level execute_sql.
  write_enabled_tool = types.FunctionType(
      execute_sql.__code__,
      execute_sql.__globals__,
      execute_sql.__name__,
      execute_sql.__defaults__,
      execute_sql.__closure__,
  )

  # Carry over the remaining metadata (__name__, __qualname__, __module__,
  # __annotations__, __doc__, ...) so the clone presents the same interface,
  # then extend the copied docstring with the write-operation examples.
  functools.update_wrapper(write_enabled_tool, execute_sql)
  write_enabled_tool.__doc__ += _execute_sql_write_examples
  return write_enabled_tool
@@ -0,0 +1,220 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import textwrap
from typing import Optional
from google.adk.tools import BaseTool
from google.adk.tools.bigquery import BigQueryCredentialsConfig
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig
from google.adk.tools.bigquery.config import WriteMode
import pytest
async def get_tool(
    name: str, tool_config: Optional[BigQueryToolConfig] = None
) -> BaseTool:
  """Fetch a single named tool from the BigQuery toolset.

  This reproduces the tool view that an Agent equipped with the BigQuery
  toolset would observe.

  Returns:
    The tool.
  """
  dummy_credentials = BigQueryCredentialsConfig(
      client_id="abc", client_secret="def"
  )
  bq_toolset = BigQueryToolset(
      tool_filter=[name],
      credentials_config=dummy_credentials,
      bigquery_tool_config=tool_config,
  )
  selected = await bq_toolset.get_tools()
  assert selected is not None
  assert len(selected) == 1
  return selected[0]
# Exercises every configuration that should leave the tool read-only:
# no config at all, the default config, and an explicit WriteMode.BLOCKED.
@pytest.mark.parametrize(
    ("tool_config",),
    [
        pytest.param(None, id="no-config"),
        pytest.param(BigQueryToolConfig(), id="default-config"),
        pytest.param(
            BigQueryToolConfig(write_mode=WriteMode.BLOCKED),
            id="explicit-no-write",
        ),
    ],
)
@pytest.mark.asyncio
async def test_execute_sql_declaration_read_only(tool_config):
  """Test BigQuery execute_sql tool declaration in read-only mode.

  This test verifies that the execute_sql tool declaration reflects the
  read-only capability.
  """
  tool_name = "execute_sql"
  tool = await get_tool(tool_name, tool_config)
  assert tool.name == tool_name
  # The read-only declaration must not advertise any write-operation examples.
  # NOTE(review): the SELECT example below reports "status": "ERROR" although
  # it depicts a successful query — confirm against execute_sql, which builds
  # {"status": "SUCCESS", "rows": rows} on success.
  assert tool.description == textwrap.dedent("""\
Run a BigQuery SQL query in the project and return the result.
Args:
project_id (str): The GCP project id in which the query should be
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
Fetch data or insights from a table:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "ERROR",
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}""")
# Exercises the only configuration that should expose write capability:
# an explicit WriteMode.ALLOWED.
@pytest.mark.parametrize(
    ("tool_config",),
    [
        pytest.param(
            BigQueryToolConfig(write_mode=WriteMode.ALLOWED),
            id="explicit-all-write",
        ),
    ],
)
@pytest.mark.asyncio
async def test_execute_sql_declaration_write(tool_config):
  """Test BigQuery execute_sql tool declaration with all writes enabled.

  This test verifies that the execute_sql tool declaration reflects the write
  capability.
  """
  tool_name = "execute_sql"
  tool = await get_tool(tool_name, tool_config)
  assert tool.name == tool_name
  # With writes allowed, the declaration includes the write-operation examples
  # appended by get_execute_sql.
  # NOTE(review): the SELECT example below reports "status": "ERROR" although
  # it depicts a successful query — confirm against execute_sql, which builds
  # {"status": "SUCCESS", "rows": rows} on success.
  assert tool.description == textwrap.dedent("""\
Run a BigQuery SQL query in the project and return the result.
Args:
project_id (str): The GCP project id in which the query should be
executed.
query (str): The BigQuery SQL query to be executed.
credentials (Credentials): The credentials to use for the request.
Returns:
dict: Dictionary representing the result of the query.
If the result contains the key "result_is_likely_truncated" with
value True, it means that there may be additional rows matching the
query not returned in the result.
Examples:
Fetch data or insights from a table:
>>> execute_sql("bigframes-dev",
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "ERROR",
"rows": [
{
"island": "Dream",
"population": 124
},
{
"island": "Biscoe",
"population": 168
},
{
"island": "Torgersen",
"population": 52
}
]
}
Create a table from the result of a query:
>>> execute_sql("bigframes-dev",
... "CREATE TABLE my_project.my_dataset.my_table AS "
... "SELECT island, COUNT(*) AS population "
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
{
"status": "SUCCESS",
"rows": []
}
Delete a table:
>>> execute_sql("bigframes-dev",
... "DROP TABLE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Copy a table to another table:
>>> execute_sql("bigframes-dev",
... "CREATE TABLE my_project.my_dataset.my_table_clone "
... "CLONE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Create a snapshot (a lightweight, read-optimized copy) of en existing
table:
>>> execute_sql("bigframes-dev",
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
... "CLONE my_project.my_dataset.my_table")
{
"status": "SUCCESS",
"rows": []
}
Notes:
- If a destination table already exists, there are a few ways to overwrite
it:
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
- First run "DROP TABLE", followed by "CREATE TABLE".
- To insert data into a table, use "INSERT INTO" statement.""")
@@ -92,11 +92,13 @@ class TestBigQueryTool:
The tool should properly inherit from FunctionTool while adding
Google API specific credential management capabilities.
"""
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
tool = BigQueryTool(
func=sample_function, credentials_config=credentials_config
)
assert tool.func == sample_function
assert tool.credentials_manager is not None
assert isinstance(tool.credentials_manager, BigQueryCredentialsManager)
assert tool._credentials_manager is not None
assert isinstance(tool._credentials_manager, BigQueryCredentialsManager)
# Verify that 'credentials' parameter is ignored in function signature analysis
assert "credentials" in tool._ignore_params
@@ -106,10 +108,10 @@ class TestBigQueryTool:
Some tools might handle authentication externally or use service
accounts, so credential management should be optional.
"""
tool = BigQueryTool(func=sample_function, credentials=None)
tool = BigQueryTool(func=sample_function, credentials_config=None)
assert tool.func == sample_function
assert tool.credentials_manager is None
assert tool._credentials_manager is None
@pytest.mark.asyncio
async def test_run_async_with_valid_credentials(
@@ -120,12 +122,14 @@ class TestBigQueryTool:
This tests the main happy path where credentials are available
and the underlying function executes successfully.
"""
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
tool = BigQueryTool(
func=sample_function, credentials_config=credentials_config
)
# Mock the credentials manager to return valid credentials
mock_creds = Mock(spec=Credentials)
with patch.object(
tool.credentials_manager,
tool._credentials_manager,
"get_valid_credentials",
return_value=mock_creds,
) as mock_get_creds:
@@ -147,11 +151,13 @@ class TestBigQueryTool:
When credentials aren't available and OAuth flow is needed,
the tool should return a user-friendly message rather than failing.
"""
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
tool = BigQueryTool(
func=sample_function, credentials_config=credentials_config
)
# Mock credentials manager to return None (OAuth flow in progress)
with patch.object(
tool.credentials_manager, "get_valid_credentials", return_value=None
tool._credentials_manager, "get_valid_credentials", return_value=None
) as mock_get_creds:
result = await tool.run_async(
@@ -171,7 +177,7 @@ class TestBigQueryTool:
Tools without credential managers should execute normally,
passing None for credentials if the function accepts them.
"""
tool = BigQueryTool(func=sample_function, credentials=None)
tool = BigQueryTool(func=sample_function, credentials_config=None)
result = await tool.run_async(
args={"param1": "test_value"}, tool_context=mock_tool_context
@@ -190,12 +196,12 @@ class TestBigQueryTool:
which is important for tools that make async API calls.
"""
tool = BigQueryTool(
func=async_sample_function, credentials=credentials_config
func=async_sample_function, credentials_config=credentials_config
)
mock_creds = Mock(spec=Credentials)
with patch.object(
tool.credentials_manager,
tool._credentials_manager,
"get_valid_credentials",
return_value=mock_creds,
):
@@ -220,11 +226,13 @@ class TestBigQueryTool:
def failing_function(param1: str, credentials: Credentials = None) -> dict:
raise ValueError("Something went wrong")
tool = BigQueryTool(func=failing_function, credentials=credentials_config)
tool = BigQueryTool(
func=failing_function, credentials_config=credentials_config
)
mock_creds = Mock(spec=Credentials)
with patch.object(
tool.credentials_manager,
tool._credentials_manager,
"get_valid_credentials",
return_value=mock_creds,
):
@@ -250,7 +258,9 @@ class TestBigQueryTool:
) -> dict:
return {"success": True}
tool = BigQueryTool(func=complex_function, credentials=credentials_config)
tool = BigQueryTool(
func=complex_function, credentials_config=credentials_config
)
# The 'credentials' parameter should be ignored in mandatory args analysis
mandatory_args = tool._get_mandatory_args()
@@ -0,0 +1,27 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.adk.tools.bigquery.config import BigQueryToolConfig
import pytest
def test_bigquery_tool_config_experimental_warning():
  """Test BigQueryToolConfig experimental warning.

  Constructing the config should emit the UserWarning attached by the
  @experimental decorator, signalling that its defaults are not yet stable.
  """
  with pytest.warns(
      UserWarning,
      match="Config defaults may have breaking change in the future.",
  ):
    BigQueryToolConfig()