You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Introduce write protected mode to BigQuery tools
This allows to protect against any write operations (e.g. update or delete a table), useful for some agents that must only be used in a read-only mode, while the user may have write permissions. PiperOrigin-RevId: 769803741
This commit is contained in:
committed by
Copybara-Service
parent
77f44a4e45
commit
6c999caa41
@@ -17,11 +17,15 @@ import os
|
||||
from google.adk.agents import llm_agent
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
import google.auth
|
||||
|
||||
RUN_WITH_ADC = False
|
||||
|
||||
|
||||
tool_config = BigQueryToolConfig(write_mode=WriteMode.ALLOWED)
|
||||
|
||||
if RUN_WITH_ADC:
|
||||
# Initialize the tools to use the application default credentials.
|
||||
application_default_credentials, _ = google.auth.default()
|
||||
@@ -37,7 +41,9 @@ else:
|
||||
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
|
||||
)
|
||||
|
||||
bigquery_toolset = BigQueryToolset(credentials_config=credentials_config)
|
||||
bigquery_toolset = BigQueryToolset(
|
||||
credentials_config=credentials_config, bigquery_tool_config=tool_config
|
||||
)
|
||||
|
||||
# The variable name `root_agent` determines what your root agent is for the
|
||||
# debug CLI
|
||||
|
||||
@@ -25,6 +25,7 @@ from ..function_tool import FunctionTool
|
||||
from ..tool_context import ToolContext
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .bigquery_credentials import BigQueryCredentialsManager
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
|
||||
class BigQueryTool(FunctionTool):
|
||||
@@ -41,21 +42,27 @@ class BigQueryTool(FunctionTool):
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Any],
|
||||
credentials: Optional[BigQueryCredentialsConfig] = None,
|
||||
*,
|
||||
credentials_config: Optional[BigQueryCredentialsConfig] = None,
|
||||
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
|
||||
):
|
||||
"""Initialize the Google API tool.
|
||||
|
||||
Args:
|
||||
func: callable that impelments the tool's logic, can accept one
|
||||
'credential" parameter
|
||||
credentials: credentials used to call Google API. If None, then we don't
|
||||
hanlde the auth logic
|
||||
credentials_config: credentials config used to call Google API. If None,
|
||||
then we don't hanlde the auth logic
|
||||
"""
|
||||
super().__init__(func=func)
|
||||
self._ignore_params.append("credentials")
|
||||
self.credentials_manager = (
|
||||
BigQueryCredentialsManager(credentials) if credentials else None
|
||||
self._ignore_params.append("config")
|
||||
self._credentials_manager = (
|
||||
BigQueryCredentialsManager(credentials_config)
|
||||
if credentials_config
|
||||
else None
|
||||
)
|
||||
self._tool_config = bigquery_tool_config
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
@@ -69,12 +76,12 @@ class BigQueryTool(FunctionTool):
|
||||
try:
|
||||
# Get valid credentials
|
||||
credentials = (
|
||||
await self.credentials_manager.get_valid_credentials(tool_context)
|
||||
if self.credentials_manager
|
||||
await self._credentials_manager.get_valid_credentials(tool_context)
|
||||
if self._credentials_manager
|
||||
else None
|
||||
)
|
||||
|
||||
if credentials is None and self.credentials_manager:
|
||||
if credentials is None and self._credentials_manager:
|
||||
# OAuth flow in progress
|
||||
return (
|
||||
"User authorization is required to access Google services for"
|
||||
@@ -84,7 +91,7 @@ class BigQueryTool(FunctionTool):
|
||||
# Execute the tool's specific logic with valid credentials
|
||||
|
||||
return await self._run_async_with_credential(
|
||||
credentials, args, tool_context
|
||||
credentials, self._tool_config, args, tool_context
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
@@ -96,6 +103,7 @@ class BigQueryTool(FunctionTool):
|
||||
async def _run_async_with_credential(
|
||||
self,
|
||||
credentials: Credentials,
|
||||
tool_config: BigQueryToolConfig,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
@@ -113,4 +121,6 @@ class BigQueryTool(FunctionTool):
|
||||
signature = inspect.signature(self.func)
|
||||
if "credentials" in signature.parameters:
|
||||
args_to_call["credentials"] = credentials
|
||||
if "config" in signature.parameters:
|
||||
args_to_call["config"] = tool_config
|
||||
return await super().run_async(args=args_to_call, tool_context=tool_context)
|
||||
|
||||
@@ -28,6 +28,7 @@ from ...tools.base_toolset import BaseToolset
|
||||
from ...tools.base_toolset import ToolPredicate
|
||||
from .bigquery_credentials import BigQueryCredentialsConfig
|
||||
from .bigquery_tool import BigQueryTool
|
||||
from .config import BigQueryToolConfig
|
||||
|
||||
|
||||
class BigQueryToolset(BaseToolset):
|
||||
@@ -38,9 +39,11 @@ class BigQueryToolset(BaseToolset):
|
||||
*,
|
||||
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
|
||||
credentials_config: Optional[BigQueryCredentialsConfig] = None,
|
||||
bigquery_tool_config: Optional[BigQueryToolConfig] = None,
|
||||
):
|
||||
self._credentials_config = credentials_config
|
||||
self.tool_filter = tool_filter
|
||||
self._credentials_config = credentials_config
|
||||
self._tool_config = bigquery_tool_config
|
||||
|
||||
def _is_tool_selected(
|
||||
self, tool: BaseTool, readonly_context: ReadonlyContext
|
||||
@@ -64,14 +67,15 @@ class BigQueryToolset(BaseToolset):
|
||||
all_tools = [
|
||||
BigQueryTool(
|
||||
func=func,
|
||||
credentials=self._credentials_config,
|
||||
credentials_config=self._credentials_config,
|
||||
bigquery_tool_config=self._tool_config,
|
||||
)
|
||||
for func in [
|
||||
metadata_tool.get_dataset_info,
|
||||
metadata_tool.get_table_info,
|
||||
metadata_tool.list_dataset_ids,
|
||||
metadata_tool.list_table_ids,
|
||||
query_tool.execute_sql,
|
||||
query_tool.get_execute_sql(self._tool_config),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...utils.feature_decorator import experimental
|
||||
|
||||
|
||||
class WriteMode(Enum):
|
||||
"""Write mode indicating what levels of write operations are allowed in BigQuery."""
|
||||
|
||||
BLOCKED = 'blocked'
|
||||
"""No write operations are allowed.
|
||||
|
||||
This mode implies that only read (i.e. SELECT query) operations are allowed.
|
||||
"""
|
||||
|
||||
ALLOWED = 'allowed'
|
||||
"""All write operations are allowed."""
|
||||
|
||||
|
||||
@experimental('Config defaults may have breaking change in the future.')
|
||||
class BigQueryToolConfig(BaseModel):
|
||||
"""Configuration for BigQuery tools."""
|
||||
|
||||
write_mode: WriteMode = WriteMode.BLOCKED
|
||||
"""Write mode for BigQuery tools.
|
||||
|
||||
By default, the tool will allow only read operations. This behaviour may
|
||||
change in future versions.
|
||||
"""
|
||||
@@ -15,7 +15,7 @@
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ...tools.bigquery import client
|
||||
from . import client
|
||||
|
||||
|
||||
def list_dataset_ids(project_id: str, credentials: Credentials) -> list[str]:
|
||||
|
||||
@@ -12,14 +12,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import types
|
||||
from typing import Callable
|
||||
|
||||
from google.cloud import bigquery
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from ...tools.bigquery import client
|
||||
from . import client
|
||||
from .config import BigQueryToolConfig
|
||||
from .config import WriteMode
|
||||
|
||||
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
|
||||
|
||||
|
||||
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
|
||||
def execute_sql(
|
||||
project_id: str,
|
||||
query: str,
|
||||
credentials: Credentials,
|
||||
config: BigQueryToolConfig,
|
||||
) -> dict:
|
||||
"""Run a BigQuery SQL query in the project and return the result.
|
||||
|
||||
Args:
|
||||
@@ -35,34 +47,49 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"rows": [
|
||||
{
|
||||
"island": "Dream",
|
||||
"population": 124
|
||||
},
|
||||
{
|
||||
"island": "Biscoe",
|
||||
"population": 168
|
||||
},
|
||||
{
|
||||
"island": "Torgersen",
|
||||
"population": 52
|
||||
}
|
||||
]
|
||||
}
|
||||
Fetch data or insights from a table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"status": "ERROR",
|
||||
"rows": [
|
||||
{
|
||||
"island": "Dream",
|
||||
"population": 124
|
||||
},
|
||||
{
|
||||
"island": "Biscoe",
|
||||
"population": 168
|
||||
},
|
||||
{
|
||||
"island": "Torgersen",
|
||||
"population": 52
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
try:
|
||||
bq_client = client.get_bigquery_client(credentials=credentials)
|
||||
if not config or config.write_mode == WriteMode.BLOCKED:
|
||||
query_job = bq_client.query(
|
||||
query,
|
||||
project=project_id,
|
||||
job_config=bigquery.QueryJobConfig(dry_run=True),
|
||||
)
|
||||
if query_job.statement_type != "SELECT":
|
||||
return {
|
||||
"status": "ERROR",
|
||||
"error_details": "Read-only mode only supports SELECT statements.",
|
||||
}
|
||||
|
||||
row_iterator = bq_client.query_and_wait(
|
||||
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
||||
)
|
||||
rows = [{key: val for key, val in row.items()} for row in row_iterator]
|
||||
result = {"rows": rows}
|
||||
result = {"status": "SUCCESS", "rows": rows}
|
||||
if (
|
||||
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
|
||||
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
||||
@@ -74,3 +101,92 @@ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
|
||||
"status": "ERROR",
|
||||
"error_details": str(ex),
|
||||
}
|
||||
|
||||
|
||||
_execute_sql_write_examples = """
|
||||
Create a table from the result of a query:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE TABLE my_project.my_dataset.my_table AS "
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Delete a table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "DROP TABLE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Copy a table to another table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE TABLE my_project.my_dataset.my_table_clone "
|
||||
... "CLONE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Create a snapshot (a lightweight, read-optimized copy) of en existing
|
||||
table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
|
||||
... "CLONE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Notes:
|
||||
- If a destination table already exists, there are a few ways to overwrite
|
||||
it:
|
||||
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
|
||||
- First run "DROP TABLE", followed by "CREATE TABLE".
|
||||
- To insert data into a table, use "INSERT INTO" statement.
|
||||
"""
|
||||
|
||||
|
||||
def get_execute_sql(config: BigQueryToolConfig) -> Callable[..., dict]:
|
||||
"""Get the execute_sql tool customized as per the given tool config.
|
||||
|
||||
Args:
|
||||
config: BigQuery tool configuration indicating the behavior of the
|
||||
execute_sql tool.
|
||||
|
||||
Returns:
|
||||
callable[..., dict]: A version of the execute_sql tool respecting the tool
|
||||
config.
|
||||
"""
|
||||
|
||||
if not config or config.write_mode == WriteMode.BLOCKED:
|
||||
return execute_sql
|
||||
|
||||
# Create a new function object using the original function's code and globals.
|
||||
# We pass the original code, globals, name, defaults, and closure.
|
||||
# This creates a raw function object without copying other metadata yet.
|
||||
execute_sql_wrapper = types.FunctionType(
|
||||
execute_sql.__code__,
|
||||
execute_sql.__globals__,
|
||||
execute_sql.__name__,
|
||||
execute_sql.__defaults__,
|
||||
execute_sql.__closure__,
|
||||
)
|
||||
|
||||
# Use functools.update_wrapper to copy over other essential attributes
|
||||
# from the original function to the new one.
|
||||
# This includes __name__, __qualname__, __module__, __annotations__, etc.
|
||||
# It specifically allows us to then set __doc__ separately.
|
||||
functools.update_wrapper(execute_sql_wrapper, execute_sql)
|
||||
|
||||
# Now, set the new docstring
|
||||
execute_sql_wrapper.__doc__ += _execute_sql_write_examples
|
||||
|
||||
return execute_sql_wrapper
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.tools import BaseTool
|
||||
from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
||||
from google.adk.tools.bigquery import BigQueryToolset
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
from google.adk.tools.bigquery.config import WriteMode
|
||||
import pytest
|
||||
|
||||
|
||||
async def get_tool(
|
||||
name: str, tool_config: Optional[BigQueryToolConfig] = None
|
||||
) -> BaseTool:
|
||||
"""Get a tool from BigQuery toolset.
|
||||
|
||||
This method gets the tool view that an Agent using the BigQuery toolset would
|
||||
see.
|
||||
|
||||
Returns:
|
||||
The tool.
|
||||
"""
|
||||
credentials_config = BigQueryCredentialsConfig(
|
||||
client_id="abc", client_secret="def"
|
||||
)
|
||||
|
||||
toolset = BigQueryToolset(
|
||||
credentials_config=credentials_config,
|
||||
tool_filter=[name],
|
||||
bigquery_tool_config=tool_config,
|
||||
)
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
assert tools is not None
|
||||
assert len(tools) == 1
|
||||
return tools[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tool_config",),
|
||||
[
|
||||
pytest.param(None, id="no-config"),
|
||||
pytest.param(BigQueryToolConfig(), id="default-config"),
|
||||
pytest.param(
|
||||
BigQueryToolConfig(write_mode=WriteMode.BLOCKED),
|
||||
id="explicit-no-write",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_declaration_read_only(tool_config):
|
||||
"""Test BigQuery execute_sql tool declaration in read-only mode.
|
||||
|
||||
This test verifies that the execute_sql tool declaration reflects the
|
||||
read-only capability.
|
||||
"""
|
||||
tool_name = "execute_sql"
|
||||
tool = await get_tool(tool_name, tool_config)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == textwrap.dedent("""\
|
||||
Run a BigQuery SQL query in the project and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the query should be
|
||||
executed.
|
||||
query (str): The BigQuery SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary representing the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
Fetch data or insights from a table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"status": "ERROR",
|
||||
"rows": [
|
||||
{
|
||||
"island": "Dream",
|
||||
"population": 124
|
||||
},
|
||||
{
|
||||
"island": "Biscoe",
|
||||
"population": 168
|
||||
},
|
||||
{
|
||||
"island": "Torgersen",
|
||||
"population": 52
|
||||
}
|
||||
]
|
||||
}""")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tool_config",),
|
||||
[
|
||||
pytest.param(
|
||||
BigQueryToolConfig(write_mode=WriteMode.ALLOWED),
|
||||
id="explicit-all-write",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_declaration_write(tool_config):
|
||||
"""Test BigQuery execute_sql tool declaration with all writes enabled.
|
||||
|
||||
This test verifies that the execute_sql tool declaration reflects the write
|
||||
capability.
|
||||
"""
|
||||
tool_name = "execute_sql"
|
||||
tool = await get_tool(tool_name, tool_config)
|
||||
assert tool.name == tool_name
|
||||
assert tool.description == textwrap.dedent("""\
|
||||
Run a BigQuery SQL query in the project and return the result.
|
||||
|
||||
Args:
|
||||
project_id (str): The GCP project id in which the query should be
|
||||
executed.
|
||||
query (str): The BigQuery SQL query to be executed.
|
||||
credentials (Credentials): The credentials to use for the request.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary representing the result of the query.
|
||||
If the result contains the key "result_is_likely_truncated" with
|
||||
value True, it means that there may be additional rows matching the
|
||||
query not returned in the result.
|
||||
|
||||
Examples:
|
||||
Fetch data or insights from a table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"status": "ERROR",
|
||||
"rows": [
|
||||
{
|
||||
"island": "Dream",
|
||||
"population": 124
|
||||
},
|
||||
{
|
||||
"island": "Biscoe",
|
||||
"population": 168
|
||||
},
|
||||
{
|
||||
"island": "Torgersen",
|
||||
"population": 52
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Create a table from the result of a query:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE TABLE my_project.my_dataset.my_table AS "
|
||||
... "SELECT island, COUNT(*) AS population "
|
||||
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Delete a table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "DROP TABLE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Copy a table to another table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE TABLE my_project.my_dataset.my_table_clone "
|
||||
... "CLONE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Create a snapshot (a lightweight, read-optimized copy) of en existing
|
||||
table:
|
||||
|
||||
>>> execute_sql("bigframes-dev",
|
||||
... "CREATE SNAPSHOT TABLE my_project.my_dataset.my_table_snapshot "
|
||||
... "CLONE my_project.my_dataset.my_table")
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"rows": []
|
||||
}
|
||||
|
||||
Notes:
|
||||
- If a destination table already exists, there are a few ways to overwrite
|
||||
it:
|
||||
- Use "CREATE OR REPLACE TABLE" instead of "CREATE TABLE".
|
||||
- First run "DROP TABLE", followed by "CREATE TABLE".
|
||||
- To insert data into a table, use "INSERT INTO" statement.""")
|
||||
@@ -92,11 +92,13 @@ class TestBigQueryTool:
|
||||
The tool should properly inherit from FunctionTool while adding
|
||||
Google API specific credential management capabilities.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
tool = BigQueryTool(
|
||||
func=sample_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
assert tool.func == sample_function
|
||||
assert tool.credentials_manager is not None
|
||||
assert isinstance(tool.credentials_manager, BigQueryCredentialsManager)
|
||||
assert tool._credentials_manager is not None
|
||||
assert isinstance(tool._credentials_manager, BigQueryCredentialsManager)
|
||||
# Verify that 'credentials' parameter is ignored in function signature analysis
|
||||
assert "credentials" in tool._ignore_params
|
||||
|
||||
@@ -106,10 +108,10 @@ class TestBigQueryTool:
|
||||
Some tools might handle authentication externally or use service
|
||||
accounts, so credential management should be optional.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=None)
|
||||
tool = BigQueryTool(func=sample_function, credentials_config=None)
|
||||
|
||||
assert tool.func == sample_function
|
||||
assert tool.credentials_manager is None
|
||||
assert tool._credentials_manager is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_with_valid_credentials(
|
||||
@@ -120,12 +122,14 @@ class TestBigQueryTool:
|
||||
This tests the main happy path where credentials are available
|
||||
and the underlying function executes successfully.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
tool = BigQueryTool(
|
||||
func=sample_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
# Mock the credentials manager to return valid credentials
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
tool._credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
) as mock_get_creds:
|
||||
@@ -147,11 +151,13 @@ class TestBigQueryTool:
|
||||
When credentials aren't available and OAuth flow is needed,
|
||||
the tool should return a user-friendly message rather than failing.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=credentials_config)
|
||||
tool = BigQueryTool(
|
||||
func=sample_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
# Mock credentials manager to return None (OAuth flow in progress)
|
||||
with patch.object(
|
||||
tool.credentials_manager, "get_valid_credentials", return_value=None
|
||||
tool._credentials_manager, "get_valid_credentials", return_value=None
|
||||
) as mock_get_creds:
|
||||
|
||||
result = await tool.run_async(
|
||||
@@ -171,7 +177,7 @@ class TestBigQueryTool:
|
||||
Tools without credential managers should execute normally,
|
||||
passing None for credentials if the function accepts them.
|
||||
"""
|
||||
tool = BigQueryTool(func=sample_function, credentials=None)
|
||||
tool = BigQueryTool(func=sample_function, credentials_config=None)
|
||||
|
||||
result = await tool.run_async(
|
||||
args={"param1": "test_value"}, tool_context=mock_tool_context
|
||||
@@ -190,12 +196,12 @@ class TestBigQueryTool:
|
||||
which is important for tools that make async API calls.
|
||||
"""
|
||||
tool = BigQueryTool(
|
||||
func=async_sample_function, credentials=credentials_config
|
||||
func=async_sample_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
tool._credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
):
|
||||
@@ -220,11 +226,13 @@ class TestBigQueryTool:
|
||||
def failing_function(param1: str, credentials: Credentials = None) -> dict:
|
||||
raise ValueError("Something went wrong")
|
||||
|
||||
tool = BigQueryTool(func=failing_function, credentials=credentials_config)
|
||||
tool = BigQueryTool(
|
||||
func=failing_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
mock_creds = Mock(spec=Credentials)
|
||||
with patch.object(
|
||||
tool.credentials_manager,
|
||||
tool._credentials_manager,
|
||||
"get_valid_credentials",
|
||||
return_value=mock_creds,
|
||||
):
|
||||
@@ -250,7 +258,9 @@ class TestBigQueryTool:
|
||||
) -> dict:
|
||||
return {"success": True}
|
||||
|
||||
tool = BigQueryTool(func=complex_function, credentials=credentials_config)
|
||||
tool = BigQueryTool(
|
||||
func=complex_function, credentials_config=credentials_config
|
||||
)
|
||||
|
||||
# The 'credentials' parameter should be ignored in mandatory args analysis
|
||||
mandatory_args = tool._get_mandatory_args()
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.tools.bigquery.config import BigQueryToolConfig
|
||||
import pytest
|
||||
|
||||
|
||||
def test_bigquery_tool_config_experimental_warning():
|
||||
"""Test BigQueryToolConfig experimental warning."""
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match="Config defaults may have breaking change in the future.",
|
||||
):
|
||||
BigQueryToolConfig()
|
||||
Reference in New Issue
Block a user