mirror of https://github.com/encounter/adk-python.git
feat: Enhance BigQuery plugin schema upgrades and error reporting
This change introduces several improvements to the BigQuery Agent Analytics Plugin:

* **Fix 1 (High):** Error callbacks (`on_model_error_callback`, `on_tool_error_callback`) now emit `status="ERROR"` instead of defaulting to `"OK"`.
* **Fix 2 (Medium):** Schema upgrade now detects missing sub-fields in nested RECORD columns via a new recursive helper. The version label is now stamped only after the `update_table` call succeeds, ensuring failures can be retried.
* **Fix 3 (Medium):** Multi-loop `shutdown()` now drains batch processors on non-current event loops using `run_coroutine_threadsafe` before closing transports.
* **Fix 4 (Medium):** Session state is truncated before logging to prevent oversized payloads.
* **Fix 5 (Low):** String system prompts are now truncated during content parsing.
* **Fix 6 (Low):** Removed the unused `_HITL_TOOL_NAMES` frozenset.

Co-authored-by: Haiyuan Cao <haiyuan@google.com>
PiperOrigin-RevId: 879147684
committed by Copybara-Service
parent feefadfcc9
commit bcf38fa2ba
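Fix 2's retry behavior is the subtlest change below: the version label must only persist when the schema update itself succeeds. A minimal standalone sketch of that idea, not the plugin's actual code; `client`, `table`, and `desired` are placeholder google-cloud-bigquery objects:

from google.cloud import bigquery

_SCHEMA_VERSION = "1"
_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"


def upgrade(client: bigquery.Client, table: bigquery.Table, desired) -> None:
  if table.labels.get(_SCHEMA_VERSION_LABEL_KEY) == _SCHEMA_VERSION:
    return  # already upgraded; skip
  existing_names = {f.name for f in table.schema}
  table.schema = list(table.schema) + [
      f for f in desired if f.name not in existing_names
  ]
  try:
    # Stamp the label in the same update_table call: if the call fails,
    # the label never persists remotely and the next run retries.
    labels = dict(table.labels or {})
    labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
    table.labels = labels
    client.update_table(table, ["schema", "labels"])
  except Exception:
    pass  # logged in the real plugin; the upgrade retries on the next run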
@@ -28,6 +28,13 @@ import json
 import logging
 import mimetypes
 import os
+
+# Enable gRPC fork support so child processes created via os.fork()
+# can safely create new gRPC channels. Must be set before grpc's
+# C-core is loaded (which happens through the google.api_core
+# imports below). setdefault respects any explicit user override.
+os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1")
+
 import random
 import time
 from types import MappingProxyType
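The ordering constraint in the comment above matters in practice: per that comment, gRPC reads GRPC_ENABLE_FORK_SUPPORT when its C-core loads, so the variable must be set before the first (possibly indirect) grpc import. A minimal illustration, not taken from the diff:

import os

# Must run before anything imports grpc (directly or via google.api_core).
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "1")

import grpc  # noqa: E402  # C-core now loads with fork support enabled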
@@ -76,19 +83,29 @@ tracer = trace.get_tracer(
 _SCHEMA_VERSION = "1"
 _SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"
 
-# Human-in-the-loop (HITL) tool names that receive additional
-# dedicated event types alongside the normal TOOL_* events.
-_HITL_TOOL_NAMES = frozenset({
-    "adk_request_credential",
-    "adk_request_confirmation",
-    "adk_request_input",
-})
 _HITL_EVENT_MAP = MappingProxyType({
     "adk_request_credential": "HITL_CREDENTIAL_REQUEST",
     "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST",
     "adk_request_input": "HITL_INPUT_REQUEST",
 })
 
+# Track all living plugin instances so the fork handler can reset
+# them proactively in the child, before _ensure_started runs.
+_LIVE_PLUGINS: weakref.WeakSet = weakref.WeakSet()
+
+
+def _after_fork_in_child() -> None:
+  """Reset every living plugin instance after os.fork()."""
+  for plugin in list(_LIVE_PLUGINS):
+    try:
+      plugin._reset_runtime_state()
+    except Exception:
+      pass
+
+
+if hasattr(os, "register_at_fork"):
+  os.register_at_fork(after_in_child=_after_fork_in_child)
+
+
 def _safe_callback(func):
   """Decorator that catches and logs exceptions in plugin callbacks.
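The `_LIVE_PLUGINS` registry relies on `weakref.WeakSet` semantics: entries disappear once a plugin is garbage-collected, so the fork handler never touches dead instances. A quick standalone check of that behavior:

import gc
import weakref


class _Dummy:
  pass


live = weakref.WeakSet()
d = _Dummy()
live.add(d)
assert d in live

del d
gc.collect()  # CPython frees immediately; gc.collect() covers other runtimes
assert len(live) == 0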
@@ -1407,7 +1424,10 @@ class HybridContentParser:
     if content.config and getattr(content.config, "system_instruction", None):
       si = content.config.system_instruction
       if isinstance(si, str):
-        json_payload["system_prompt"] = si
+        truncated_si, trunc = process_text(si)
+        if trunc:
+          is_truncated = True
+        json_payload["system_prompt"] = truncated_si
       else:
         summary, parts, trunc = await self._parse_content_object(si)
         if trunc:
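`process_text` is the plugin's existing truncation helper; only its return shape `(text, was_truncated)` is visible in this hunk. A hedged stand-in with that shape, for illustration only (the marker text is an assumption based on the tests below, which assert that "TRUNCATED" appears in the payload):

def process_text(text: str, max_len: int = 50) -> tuple[str, bool]:
  """Illustrative stand-in, not the plugin's implementation."""
  marker = "...[TRUNCATED]"
  if len(text) <= max_len:
    return text, False
  return text[: max(0, max_len - len(marker))] + marker, True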
@@ -1855,6 +1875,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
     self._schema = None
     self.arrow_schema = None
     self._init_pid = os.getpid()
+    _LIVE_PLUGINS.add(self)
 
   def _cleanup_stale_loop_states(self) -> None:
     """Removes entries for event loops that have been closed."""
@@ -2142,9 +2163,73 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
         exc_info=True,
     )
 
+  @staticmethod
+  def _schema_fields_match(
+      existing: list[bq_schema.SchemaField],
+      desired: list[bq_schema.SchemaField],
+  ) -> tuple[
+      list[bq_schema.SchemaField],
+      list[bq_schema.SchemaField],
+  ]:
+    """Compares existing vs desired schema fields recursively.
+
+    Returns:
+      A tuple of (new_top_level_fields, updated_record_fields).
+      ``new_top_level_fields`` are fields in *desired* that are
+      entirely absent from *existing*.
+      ``updated_record_fields`` are RECORD fields that exist in
+      both but have new sub-fields in *desired*; each entry is a
+      copy of the existing field with the missing sub-fields
+      appended.
+    """
+    existing_by_name = {f.name: f for f in existing}
+    new_fields: list[bq_schema.SchemaField] = []
+    updated_records: list[bq_schema.SchemaField] = []
+
+    for desired_field in desired:
+      existing_field = existing_by_name.get(desired_field.name)
+      if existing_field is None:
+        new_fields.append(desired_field)
+      elif (
+          desired_field.field_type == "RECORD"
+          and existing_field.field_type == "RECORD"
+          and desired_field.fields
+      ):
+        # Recurse into nested RECORD fields.
+        sub_new, sub_updated = (
+            BigQueryAgentAnalyticsPlugin._schema_fields_match(
+                list(existing_field.fields),
+                list(desired_field.fields),
+            )
+        )
+        if sub_new or sub_updated:
+          # Build a merged sub-field list.
+          merged_sub = list(existing_field.fields)
+          # Replace updated nested records in-place.
+          updated_names = {f.name for f in sub_updated}
+          merged_sub = [
+              next(u for u in sub_updated if u.name == f.name)
+              if f.name in updated_names
+              else f
+              for f in merged_sub
+          ]
+          # Append entirely new sub-fields.
+          merged_sub.extend(sub_new)
+          # Rebuild via API representation to preserve all
+          # existing field attributes (policy_tags, etc.).
+          api_repr = existing_field.to_api_repr()
+          api_repr["fields"] = [sf.to_api_repr() for sf in merged_sub]
+          updated_records.append(bq_schema.SchemaField.from_api_repr(api_repr))
+
+    return new_fields, updated_records
+
   def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
     """Adds missing columns to an existing table (additive only).
 
+    Handles nested RECORD fields by recursing into sub-fields.
+    The version label is only stamped after a successful update
+    so that a failed attempt is retried on the next run.
+
     Args:
       existing_table: The current BigQuery table object.
     """
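A usage sketch for the helper above; it requires google-cloud-bigquery and assumes the plugin class from this diff is importable, with expected outputs following the docstring:

from google.cloud.bigquery import SchemaField

existing = [
    SchemaField("metadata", "RECORD", fields=[SchemaField("key", "STRING")]),
]
desired = [
    SchemaField(
        "metadata",
        "RECORD",
        fields=[SchemaField("key", "STRING"), SchemaField("value", "STRING")],
    ),
    SchemaField("new_col", "STRING"),
]

new, updated = BigQueryAgentAnalyticsPlugin._schema_fields_match(existing, desired)
# new     -> [SchemaField("new_col", "STRING")]
# updated -> one rebuilt "metadata" RECORD whose sub-fields are key + value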
@@ -2154,24 +2239,43 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
     if stored_version == _SCHEMA_VERSION:
       return
 
-    existing_names = {f.name for f in existing_table.schema}
-    new_fields = [f for f in self._schema if f.name not in existing_names]
+    new_fields, updated_records = self._schema_fields_match(
+        list(existing_table.schema), list(self._schema)
+    )
 
-    if new_fields:
-      merged = list(existing_table.schema) + new_fields
+    if new_fields or updated_records:
+      # Build merged top-level schema.
+      updated_names = {f.name for f in updated_records}
+      merged = [
+          next(u for u in updated_records if u.name == f.name)
+          if f.name in updated_names
+          else f
+          for f in existing_table.schema
+      ]
+      merged.extend(new_fields)
       existing_table.schema = merged
 
+      change_desc = []
+      if new_fields:
+        change_desc.append(f"new columns {[f.name for f in new_fields]}")
+      if updated_records:
+        change_desc.append(
+            f"updated RECORD fields {[f.name for f in updated_records]}"
+        )
       logger.info(
-          "Auto-upgrading table %s: adding columns %s",
+          "Auto-upgrading table %s: %s",
           self.full_table_id,
-          [f.name for f in new_fields],
+          ", ".join(change_desc),
       )
 
-    # Always stamp the version label so we skip on next run.
-    labels = dict(existing_table.labels or {})
-    labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
-    existing_table.labels = labels
-
     try:
+      # Stamp the version label inside the try block so that
+      # on failure the label is NOT persisted and the next run
+      # retries the upgrade.
+      labels = dict(existing_table.labels or {})
+      labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
+      existing_table.labels = labels
+
       update_fields = ["schema", "labels"]
       self.client.update_table(existing_table, update_fields)
     except Exception as e:
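The top-level merge above can be read in isolation: updated RECORDs replace their same-named originals in place, order is preserved, and brand-new columns are appended at the end. A sketch with plain strings standing in for SchemaField objects (the dict keying is a simplification for the illustration):

existing = ["ts", "metadata_old", "status"]
updated_records = {"metadata_old": "metadata_new"}  # keyed by name here
new_fields = ["new_col"]

merged = [updated_records.get(f, f) for f in existing]
merged.extend(new_fields)
assert merged == ["ts", "metadata_new", "status", "new_col"]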
@@ -2243,6 +2347,22 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
     if loop in self._loop_state_by_loop:
       await self._loop_state_by_loop[loop].batch_processor.shutdown(timeout=t)
 
+    # 1b. Drain batch processors on other (non-current) loops.
+    for other_loop, state in self._loop_state_by_loop.items():
+      if other_loop is loop or other_loop.is_closed():
+        continue
+      try:
+        future = asyncio.run_coroutine_threadsafe(
+            state.batch_processor.shutdown(timeout=t),
+            other_loop,
+        )
+        future.result(timeout=t)
+      except Exception:
+        logger.warning(
+            "Could not drain batch processor on loop %s",
+            other_loop,
+        )
+
     # 2. Close clients for all states
     for state in self._loop_state_by_loop.values():
       if state.write_client and getattr(
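Fix 3's drain pattern in isolation: `asyncio.run_coroutine_threadsafe` is the supported way to schedule a coroutine on a loop running in another thread and block on its result with a timeout. A runnable standalone sketch, where `drain` stands in for `batch_processor.shutdown`:

import asyncio
import threading

other_loop = asyncio.new_event_loop()
threading.Thread(target=other_loop.run_forever, daemon=True).start()


async def drain():
  await asyncio.sleep(0)  # stand-in for batch_processor.shutdown(timeout=t)


future = asyncio.run_coroutine_threadsafe(drain(), other_loop)
future.result(timeout=5)  # blocks the calling thread, like the code above
other_loop.call_soon_threadsafe(other_loop.stop)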
@@ -2298,6 +2418,38 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
     process. Pure-data fields like ``_schema`` and
     ``arrow_schema`` are kept because they are safe across fork.
     """
+    logger.warning(
+        "Fork detected (parent PID %s, child PID %s). Resetting"
+        " gRPC state for BigQuery analytics plugin. Note: gRPC"
+        " bidirectional streaming (used by the BigQuery Storage"
+        " Write API) is not fork-safe. If writes hang or time"
+        " out, configure the 'spawn' start method at your program"
+        " entry-point before creating child processes:"
+        " multiprocessing.set_start_method('spawn')",
+        self._init_pid,
+        os.getpid(),
+    )
+    # Best-effort: close inherited gRPC channels so broken
+    # finalizers don't interfere with newly created channels.
+    # For grpc.aio channels, close() is a coroutine. We cannot
+    # await here (called from sync context / fork handler), so
+    # we skip async channels and only close sync ones.
+    for loop_state in self._loop_state_by_loop.values():
+      wc = getattr(loop_state, "write_client", None)
+      transport = getattr(wc, "transport", None)
+      if transport is not None:
+        try:
+          channel = getattr(transport, "_grpc_channel", None)
+          if channel is not None and hasattr(channel, "close"):
+            result = channel.close()
+            # If close() returned a coroutine (grpc.aio channel),
+            # discard it to avoid unawaited-coroutine warnings.
+            if asyncio.iscoroutine(result):
+              result.close()
+        except Exception:
+          pass
+
+    # Clear all runtime state.
     self._setup_lock = None
     self.client = None
     self._loop_state_by_loop = {}
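The `iscoroutine`/`close` dance above, in isolation: calling an async `close()` from sync code yields a coroutine object that can never be awaited there, and `coroutine.close()` disposes of it without triggering the "coroutine was never awaited" RuntimeWarning. Standalone sketch, with `aio_close` standing in for a grpc.aio channel's close:

import asyncio


async def aio_close():
  """Stands in for grpc.aio's coroutine-returning channel.close()."""


result = aio_close()  # a coroutine object, never scheduled
if asyncio.iscoroutine(result):
  result.close()  # dispose without awaiting; no RuntimeWarning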
@@ -2442,7 +2594,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
       # Include session state if non-empty (contains user-set metadata
       # like gchat thread-id, customer_id, etc.)
       if session.state:
-        session_meta["state"] = dict(session.state)
+        truncated_state, _ = _recursive_smart_truncate(
+            dict(session.state),
+            self.config.max_content_length,
+        )
+        session_meta["state"] = truncated_state
       attrs["session_metadata"] = session_meta
     except Exception:
       pass
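`_recursive_smart_truncate` is an existing helper in the plugin module; only its call shape `(value, max_len) -> (truncated_value, was_truncated)` appears in this hunk. An illustrative re-implementation under those assumptions, with the marker text guessed from the tests below:

def recursive_truncate(value, max_len: int):
  """Hedged stand-in for _recursive_smart_truncate (illustration only)."""
  if isinstance(value, dict):
    out, truncated = {}, False
    for key, sub in value.items():
      out[key], t = recursive_truncate(sub, max_len)
      truncated = truncated or t
    return out, truncated
  if isinstance(value, str) and len(value) > max_len:
    return value[:max_len] + "...[TRUNCATED]", True
  return value, False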
@@ -2988,6 +3144,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
         "LLM_ERROR",
         callback_context,
         event_data=EventData(
+            status="ERROR",
             error_message=str(error),
             latency_ms=duration,
             span_id_override=None if has_ambient else span_id,
@@ -3110,6 +3267,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
         raw_content=content_dict,
         is_truncated=is_truncated,
         event_data=EventData(
+            status="ERROR",
             error_message=str(error),
             latency_ms=duration,
             span_id_override=None if has_ambient else span_id,

@@ -17,6 +17,7 @@ import asyncio
 import contextlib
 import dataclasses
 import json
+import os
 from unittest import mock
 
 from google.adk.agents import base_agent
@@ -1734,6 +1735,7 @@ class TestBigQueryAgentAnalyticsPlugin:
     _assert_common_fields(log_entry, "LLM_ERROR")
     assert log_entry["content"] is None
     assert log_entry["error_message"] == "LLM failed"
+    assert log_entry["status"] == "ERROR"
 
   @pytest.mark.asyncio
   async def test_on_tool_error_callback_logs_correctly(
@@ -1761,6 +1763,7 @@ class TestBigQueryAgentAnalyticsPlugin:
     assert content_dict["tool"] == "MyTool"
     assert content_dict["args"] == {"param": "value"}
     assert log_entry["error_message"] == "Tool timed out"
+    assert log_entry["status"] == "ERROR"
 
   @pytest.mark.asyncio
   async def test_table_creation_options(
@@ -4829,7 +4832,6 @@ class TestForkSafety:
     # _ensure_started should detect PID mismatch and reset
     await plugin._ensure_started()
     # After reset + re-init, _init_pid should match current
-    import os
 
     assert plugin._init_pid == os.getpid()
     assert plugin._started is True
@@ -4884,8 +4886,6 @@ class TestForkSafety:
     assert plugin._schema == ["kept"]
     assert plugin.arrow_schema == "kept_arrow"
 
-    import os
-
     assert plugin._init_pid == os.getpid()
 
   def test_getstate_resets_pid(self):
@@ -4920,6 +4920,134 @@ class TestForkSafety:
       await new_plugin.shutdown()
 
 
+class TestForkGrpcSafety:
+  """Tests for gRPC fork safety enhancements."""
+
+  def _make_plugin(self):
+    config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig()
+    return bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_ID,
+        table_id=TABLE_ID,
+        config=config,
+    )
+
+  def test_grpc_fork_env_var_set(self):
+    """GRPC_ENABLE_FORK_SUPPORT should be '1' after import."""
+
+    assert os.environ.get("GRPC_ENABLE_FORK_SUPPORT") == "1"
+
+  def test_register_at_fork_resets_all_instances(self):
+    """_after_fork_in_child resets all living plugin instances."""
+    p1 = self._make_plugin()
+    p2 = self._make_plugin()
+    p1._started = True
+    p2._started = True
+    p1._init_pid = -1
+    p2._init_pid = -1
+
+    bigquery_agent_analytics_plugin._after_fork_in_child()
+
+    assert p1._started is False
+    assert p2._started is False
+    assert p1._init_pid == os.getpid()
+    assert p2._init_pid == os.getpid()
+
+  def test_dead_plugin_removed_from_live_set(self):
+    """WeakSet should not hold dead plugin references."""
+    p = self._make_plugin()
+    assert p in bigquery_agent_analytics_plugin._LIVE_PLUGINS
+    pid = id(p)
+    del p
+    # After deletion, the WeakSet should no longer contain it.
+    for alive in bigquery_agent_analytics_plugin._LIVE_PLUGINS:
+      assert id(alive) != pid
+
+  def test_reset_closes_inherited_sync_transports(self):
+    """_reset_runtime_state closes inherited sync gRPC channels."""
+    plugin = self._make_plugin()
+    mock_channel = mock.MagicMock()
+    mock_channel.close.return_value = None  # sync close
+    mock_transport = mock.MagicMock()
+    mock_transport._grpc_channel = mock_channel
+    mock_wc = mock.MagicMock()
+    mock_wc.transport = mock_transport
+
+    mock_loop_state = mock.MagicMock()
+    mock_loop_state.write_client = mock_wc
+
+    plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state}
+    plugin._init_pid = -1
+
+    plugin._reset_runtime_state()
+
+    mock_channel.close.assert_called_once()
+
+  def test_reset_discards_async_channel_close_coroutine(self):
+    """Async channel close() returns a coroutine; must not warn."""
+    import warnings
+
+    plugin = self._make_plugin()
+
+    async def _async_close():
+      pass
+
+    mock_channel = mock.MagicMock()
+    mock_channel.close.return_value = _async_close()
+    mock_transport = mock.MagicMock()
+    mock_transport._grpc_channel = mock_channel
+    mock_wc = mock.MagicMock()
+    mock_wc.transport = mock_transport
+
+    mock_loop_state = mock.MagicMock()
+    mock_loop_state.write_client = mock_wc
+
+    plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state}
+    plugin._init_pid = -1
+
+    with warnings.catch_warnings():
+      warnings.simplefilter("error", RuntimeWarning)
+      # Must not raise RuntimeWarning for unawaited coroutine
+      plugin._reset_runtime_state()
+
+    mock_channel.close.assert_called_once()
+
+  def test_transport_close_exception_swallowed(self):
+    """close() raising should not prevent reset from completing."""
+    plugin = self._make_plugin()
+    mock_channel = mock.MagicMock()
+    mock_channel.close.side_effect = RuntimeError("broken channel")
+    mock_transport = mock.MagicMock()
+    mock_transport._grpc_channel = mock_channel
+    mock_wc = mock.MagicMock()
+    mock_wc.transport = mock_transport
+
+    mock_loop_state = mock.MagicMock()
+    mock_loop_state.write_client = mock_wc
+
+    plugin._loop_state_by_loop = {mock.MagicMock(): mock_loop_state}
+    plugin._init_pid = -1
+
+    # Should not raise
+    plugin._reset_runtime_state()
+
+    assert plugin._started is False
+    assert plugin._loop_state_by_loop == {}
+
+  def test_reset_logs_fork_warning(self):
+    """_reset_runtime_state logs a warning with 'Fork detected'."""
+    plugin = self._make_plugin()
+    plugin._init_pid = -1
+
+    with mock.patch.object(
+        bigquery_agent_analytics_plugin.logger, "warning"
+    ) as mock_warn:
+      plugin._reset_runtime_state()
+
+    mock_warn.assert_called_once()
+    assert "Fork detected" in mock_warn.call_args[0][0]
+
+
 # ==============================================================================
 # Analytics Views Tests
 # ==============================================================================
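test_reset_discards_async_channel_close_coroutine above relies on the standard warnings-as-errors pattern, shown here in isolation: inside the context, any RuntimeWarning (such as "coroutine ... was never awaited") is raised as an exception and fails the test immediately:

import warnings

with warnings.catch_warnings():
  warnings.simplefilter("error", RuntimeWarning)
  # Code under test runs here; an unawaited coroutine would now raise
  # rather than just print a warning at interpreter shutdown.
  pass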
@@ -6057,3 +6185,270 @@ class TestAfterRunCleanupExceptionSafety:
     assert bigquery_agent_analytics_plugin._root_agent_name_ctx.get() is None
 
     provider.shutdown()
+
+
+class TestStringSystemPromptTruncation:
+  """Tests that a string system prompt is truncated in parse()."""
+
+  @pytest.mark.asyncio
+  async def test_long_string_system_prompt_is_truncated(self):
+    """A string system_instruction exceeding max_content_length is truncated."""
+    parser = bigquery_agent_analytics_plugin.HybridContentParser(
+        offloader=None,
+        trace_id="test-trace",
+        span_id="test-span",
+        max_length=50,
+    )
+    long_prompt = "A" * 200
+    llm_request = llm_request_lib.LlmRequest(
+        model="gemini-pro",
+        contents=[types.Content(parts=[types.Part(text="Hi")])],
+        config=types.GenerateContentConfig(
+            system_instruction=long_prompt,
+        ),
+    )
+    payload, _, is_truncated = await parser.parse(llm_request)
+    assert is_truncated
+    assert len(payload["system_prompt"]) < 200
+    assert "TRUNCATED" in payload["system_prompt"]
+
+
+class TestSessionStateTruncation:
+  """Tests that session state is truncated in _enrich_attributes."""
+
+  @pytest.mark.asyncio
+  async def test_oversized_session_state_is_truncated(
+      self,
+      mock_auth_default,
+      mock_bq_client,
+      mock_write_client,
+      mock_to_arrow_schema,
+      mock_asyncio_to_thread,
+      mock_session,
+      invocation_context,
+  ):
+    """Session state with large values is truncated."""
+    config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig(
+        max_content_length=30,
+    )
+    plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_ID,
+        table_id=TABLE_ID,
+        config=config,
+    )
+    await plugin._ensure_started()
+
+    # Set a large session state value.
+    large_value = "X" * 200
+    type(mock_session).state = mock.PropertyMock(
+        return_value={"big_key": large_value}
+    )
+
+    callback_ctx = CallbackContext(invocation_context=invocation_context)
+    event_data = bigquery_agent_analytics_plugin.EventData()
+    attrs = plugin._enrich_attributes(event_data, callback_ctx)
+    state = attrs["session_metadata"]["state"]
+    assert len(state["big_key"]) < 200
+    assert "TRUNCATED" in state["big_key"]
+    await plugin.shutdown()
+
+
+class TestSchemaUpgradeNestedFields:
+  """Tests for nested RECORD field detection in schema upgrade."""
+
+  def _make_plugin(self):
+    config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig(
+        auto_schema_upgrade=True,
+    )
+    with mock.patch("google.cloud.bigquery.Client"):
+      plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+          project_id=PROJECT_ID,
+          dataset_id=DATASET_ID,
+          table_id=TABLE_ID,
+          config=config,
+      )
+    plugin.client = mock.MagicMock()
+    plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}"
+    return plugin
+
+  def test_nested_field_detected(self):
+    """A new sub-field in a RECORD triggers an upgrade."""
+    plugin = self._make_plugin()
+
+    existing_record = bigquery.SchemaField(
+        "metadata",
+        "RECORD",
+        fields=[
+            bigquery.SchemaField("key", "STRING"),
+        ],
+    )
+    desired_record = bigquery.SchemaField(
+        "metadata",
+        "RECORD",
+        fields=[
+            bigquery.SchemaField("key", "STRING"),
+            bigquery.SchemaField("value", "STRING"),
+        ],
+    )
+    plugin._schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+        desired_record,
+    ]
+
+    existing = mock.MagicMock(spec=bigquery.Table)
+    existing.schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+        existing_record,
+    ]
+    existing.labels = {}
+    plugin.client.get_table.return_value = existing
+    plugin._ensure_schema_exists()
+
+    plugin.client.update_table.assert_called_once()
+    updated_table = plugin.client.update_table.call_args[0][0]
+    # Find the metadata field and check it has both sub-fields.
+    metadata_field = next(
+        f for f in updated_table.schema if f.name == "metadata"
+    )
+    sub_names = {sf.name for sf in metadata_field.fields}
+    assert "key" in sub_names
+    assert "value" in sub_names
+
+  def test_version_label_not_stamped_on_failure(self):
+    """A failed update_table does not persist the version label."""
+    plugin = self._make_plugin()
+    plugin._schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+        bigquery.SchemaField("new_col", "STRING"),
+    ]
+
+    existing = mock.MagicMock(spec=bigquery.Table)
+    existing.schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+    ]
+    existing.labels = {}
+    plugin.client.get_table.return_value = existing
+    plugin.client.update_table.side_effect = Exception("network error")
+
+    # Should not raise.
+    plugin._ensure_schema_exists()
+
+    # The label is set on the table object before update_table is
+    # called, but since update_table failed the label was never
+    # persisted remotely. On the next run the stored_version will
+    # still be None (from the real BQ table) so the upgrade retries.
+    # We verify that update_table was actually attempted.
+    plugin.client.update_table.assert_called_once()
+
+  def test_nested_upgrade_preserves_policy_tags(self):
+    """RECORD field metadata (e.g. policy_tags) is preserved on upgrade."""
+    from google.cloud.bigquery import schema as bq_schema
+
+    plugin = self._make_plugin()
+
+    existing_record = bigquery.SchemaField(
+        "metadata",
+        "RECORD",
+        policy_tags=bq_schema.PolicyTagList(
+            names=["projects/p/locations/us/taxonomies/t/policyTags/pt"]
+        ),
+        fields=[
+            bigquery.SchemaField("key", "STRING"),
+        ],
+    )
+    desired_record = bigquery.SchemaField(
+        "metadata",
+        "RECORD",
+        fields=[
+            bigquery.SchemaField("key", "STRING"),
+            bigquery.SchemaField("value", "STRING"),
+        ],
+    )
+    plugin._schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+        desired_record,
+    ]
+
+    existing = mock.MagicMock(spec=bigquery.Table)
+    existing.schema = [
+        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+        existing_record,
+    ]
+    existing.labels = {}
+    plugin.client.get_table.return_value = existing
+    plugin._ensure_schema_exists()
+
+    plugin.client.update_table.assert_called_once()
+    updated_table = plugin.client.update_table.call_args[0][0]
+    metadata_field = next(
+        f for f in updated_table.schema if f.name == "metadata"
+    )
+    # Sub-fields were merged.
+    sub_names = {sf.name for sf in metadata_field.fields}
+    assert "key" in sub_names
+    assert "value" in sub_names
+    # policy_tags preserved from the existing field.
+    assert metadata_field.policy_tags is not None
+    assert (
+        "projects/p/locations/us/taxonomies/t/policyTags/pt"
+        in metadata_field.policy_tags.names
+    )
+
+
+class TestMultiLoopShutdownDrainsOtherLoops:
+  """Tests that shutdown() drains batch processors on other loops."""
+
+  @pytest.mark.asyncio
+  async def test_other_loop_batch_processor_drained(
+      self,
+      mock_auth_default,
+      mock_bq_client,
+      mock_write_client,
+      mock_to_arrow_schema,
+      mock_asyncio_to_thread,
+  ):
+    """Shutdown drains batch_processor.shutdown on non-current loops."""
+    plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_ID,
+        table_id=TABLE_ID,
+    )
+    await plugin._ensure_started()
+
+    # Create a mock "other" loop with a mock batch processor.
+    other_loop = mock.MagicMock(spec=asyncio.AbstractEventLoop)
+    other_loop.is_closed.return_value = False
+
+    mock_other_bp = mock.AsyncMock()
+    mock_other_write_client = mock.MagicMock()
+    mock_other_write_client.transport = mock.AsyncMock()
+
+    other_state = bigquery_agent_analytics_plugin._LoopState(
+        write_client=mock_other_write_client,
+        batch_processor=mock_other_bp,
+    )
+    plugin._loop_state_by_loop[other_loop] = other_state
+
+    # Patch run_coroutine_threadsafe to verify it's called for
+    # the other loop's batch_processor. Close the coroutine arg
+    # to avoid "coroutine was never awaited" RuntimeWarning.
+    mock_future = mock.MagicMock()
+    mock_future.result.return_value = None
+
+    def _fake_run_coroutine_threadsafe(coro, loop):
+      coro.close()
+      return mock_future
+
+    with mock.patch.object(
+        asyncio,
+        "run_coroutine_threadsafe",
+        side_effect=_fake_run_coroutine_threadsafe,
+    ) as mock_rcts:
+      await plugin.shutdown()
+
+    # Verify run_coroutine_threadsafe was called with
+    # the other loop.
+    mock_rcts.assert_called()
+    call_args = mock_rcts.call_args
+    assert call_args[0][1] is other_loop