You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: allow setting max_billed_bytes in BigQuery tools config
This will allow users to configure a limit to access in ADK tools on the charges for queries. PiperOrigin-RevId: 831921163
This commit is contained in:
committed by
Copybara-Service
parent
22eb7e5b06
commit
ffbb0b37e1
@@ -61,6 +61,14 @@ class BigQueryToolConfig(BaseModel):
|
||||
change in future versions.
|
||||
"""
|
||||
|
||||
maximum_bytes_billed: Optional[int] = None
|
||||
"""Maximum number of bytes to bill for a query.
|
||||
|
||||
In BigQuery on-demand pricing, charges are rounded up to the nearest MB, with
|
||||
a minimum 10 MB data processed per table referenced by the query, and with a
|
||||
minimum 10 MB data processed per query. So this value must be set >=10485760.
|
||||
"""
|
||||
|
||||
max_query_result_rows: int = 50
|
||||
"""Maximum number of rows to return from a query.
|
||||
|
||||
@@ -91,6 +99,19 @@ class BigQueryToolConfig(BaseModel):
|
||||
locations, see https://cloud.google.com/bigquery/docs/locations.
|
||||
"""
|
||||
|
||||
@field_validator('maximum_bytes_billed')
|
||||
@classmethod
|
||||
def validate_maximum_bytes_billed(cls, v):
|
||||
"""Validate the maximum bytes billed."""
|
||||
if v and v < 10_485_760:
|
||||
raise ValueError(
|
||||
'In BigQuery on-demand pricing, charges are rounded up to the nearest'
|
||||
' MB, with a minimum 10 MB data processed per table referenced by the'
|
||||
' query, and with a minimum 10 MB data processed per query. So'
|
||||
' max_bytes_billed must be set >=10485760.'
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator('application_name')
|
||||
@classmethod
|
||||
def validate_application_name(cls, v):
|
||||
|
||||
@@ -152,12 +152,15 @@ def _execute_sql(
|
||||
return {"status": "SUCCESS", "dry_run_info": dry_run_job.to_api_repr()}
|
||||
|
||||
# Finally execute the query, fetch the result, and return it
|
||||
job_config = bigquery.QueryJobConfig(
|
||||
connection_properties=bq_connection_properties,
|
||||
labels=bq_job_labels,
|
||||
)
|
||||
if settings.maximum_bytes_billed:
|
||||
job_config.maximum_bytes_billed = settings.maximum_bytes_billed
|
||||
row_iterator = bq_client.query_and_wait(
|
||||
query,
|
||||
job_config=bigquery.QueryJobConfig(
|
||||
connection_properties=bq_connection_properties,
|
||||
labels=bq_job_labels,
|
||||
),
|
||||
job_config=job_config,
|
||||
project=project_id,
|
||||
max_results=settings.max_query_result_rows,
|
||||
)
|
||||
|
||||
@@ -1826,3 +1826,26 @@ def test_execute_sql_no_truncation():
|
||||
# Check no truncation flag when fewer rows than limit
|
||||
assert result["status"] == "SUCCESS"
|
||||
assert "result_is_likely_truncated" not in result
|
||||
|
||||
|
||||
def test_execute_sql_maximum_bytes_billed_config():
|
||||
"""Test execute_sql tool respects maximum_bytes_billed from config."""
|
||||
project = "my_project"
|
||||
query = "SELECT 123 AS num"
|
||||
statement_type = "SELECT"
|
||||
credentials = mock.create_autospec(Credentials, instance=True)
|
||||
tool_config = BigQueryToolConfig(maximum_bytes_billed=11_000_000)
|
||||
tool_context = mock.create_autospec(ToolContext, instance=True)
|
||||
|
||||
with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client:
|
||||
bq_client = Client.return_value
|
||||
query_job = mock.create_autospec(bigquery.QueryJob)
|
||||
query_job.statement_type = statement_type
|
||||
bq_client.query.return_value = query_job
|
||||
|
||||
execute_sql(project, query, credentials, tool_config, tool_context)
|
||||
|
||||
# Check that maximum_bytes_billed was called with config value
|
||||
bq_client.query_and_wait.assert_called_once()
|
||||
call_args = bq_client.query_and_wait.call_args
|
||||
assert call_args.kwargs["job_config"].maximum_bytes_billed == 11_000_000
|
||||
|
||||
@@ -56,3 +56,24 @@ def test_bigquery_tool_config_max_query_result_rows_custom():
|
||||
with pytest.warns(UserWarning):
|
||||
config = BigQueryToolConfig(max_query_result_rows=100)
|
||||
assert config.max_query_result_rows == 100
|
||||
|
||||
|
||||
def test_bigquery_tool_config_valid_maximum_bytes_billed():
|
||||
"""Test BigQueryToolConfig raises exception with valid max bytes billed."""
|
||||
with pytest.warns(UserWarning):
|
||||
config = BigQueryToolConfig(maximum_bytes_billed=10_485_760)
|
||||
assert config.maximum_bytes_billed == 10_485_760
|
||||
|
||||
|
||||
def test_bigquery_tool_config_invalid_maximum_bytes_billed():
|
||||
"""Test BigQueryToolConfig raises exception with invalid max bytes billed."""
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"In BigQuery on-demand pricing, charges are rounded up to the nearest"
|
||||
" MB, with a minimum 10 MB data processed per table referenced by the"
|
||||
" query, and with a minimum 10 MB data processed per query. So"
|
||||
" max_bytes_billed must be set >=10485760."
|
||||
),
|
||||
):
|
||||
BigQueryToolConfig(maximum_bytes_billed=10_485_759)
|
||||
|
||||
Reference in New Issue
Block a user