fix: Enhance BigQuery Plugin Robustness and Schema Accuracy

This update improves the `BigQueryAgentAnalyticsPlugin` in several ways:

*   Corrects the PyArrow schema generation to accurately reflect BigQuery field nullability based on the `mode` attribute.
*   Introduces a configurable `shutdown_timeout` in `BigQueryLoggerConfig` to manage how long the plugin waits for pending logs to flush during shutdown.
*   Adds more robust error handling within the `shutdown` method and background write tasks, particularly for event loop closure issues.
*   Improves internal logging to provide better diagnostics.
*   Ensures consistent use of safe content formatting.

PiperOrigin-RevId: 831225837
This commit is contained in:
Google Team Member
2025-11-11 22:46:39 -08:00
committed by Copybara-Service
parent a501c59ac4
commit 37ee1869b5
2 changed files with 204 additions and 64 deletions
@@ -23,7 +23,6 @@ from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Set
from typing import TYPE_CHECKING
from google.api_core.gapic_v1 import client_info as gapic_client_info
@@ -173,10 +172,11 @@ def _bq_to_arrow_field(bq_field):
metadata = _BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get(
bq_field.field_type.upper() if bq_field.field_type else ""
)
nullable = bq_field.mode.upper() != "REQUIRED"
return pa.field(
bq_field.name,
arrow_type,
nullable=(bq_field.mode != "REPEATED"),
nullable=nullable,
metadata=metadata,
)
logging.warning(
@@ -213,12 +213,18 @@ class BigQueryLoggerConfig:
event_denylist: A list of event types to skip logging.
content_formatter: An optional function to format event content before
logging.
shutdown_timeout: Seconds to wait for logs to flush during shutdown.
client_close_timeout: Seconds to wait for BQ client to close.
max_content_length: The maximum length of content parts before truncation.
"""
enabled: bool = True
event_allowlist: Optional[List[str]] = None
event_denylist: Optional[List[str]] = None
content_formatter: Optional[Callable[[Any], str]] = None
shutdown_timeout: float = 5.0
client_close_timeout: float = 2.0
max_content_length: int = 500
# --- Helper Formatters ---
@@ -313,16 +319,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
self._write_client: BigQueryWriteAsyncClient | None = None
self._init_lock: asyncio.Lock | None = None
self._arrow_schema: pa.Schema | None = None
self._background_tasks: Set[asyncio.Task] = set() # Track pending logs
self._background_tasks: set[asyncio.Task] = set()
self._is_shutting_down = False
self._schema = [
bigquery.SchemaField("timestamp", "TIMESTAMP"),
bigquery.SchemaField("event_type", "STRING"),
bigquery.SchemaField("agent", "STRING"),
bigquery.SchemaField("session_id", "STRING"),
bigquery.SchemaField("invocation_id", "STRING"),
bigquery.SchemaField("user_id", "STRING"),
bigquery.SchemaField("content", "STRING"),
bigquery.SchemaField("error_message", "STRING"),
bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"),
bigquery.SchemaField("event_type", "STRING", mode="NULLABLE"),
bigquery.SchemaField("agent", "STRING", mode="NULLABLE"),
bigquery.SchemaField("session_id", "STRING", mode="NULLABLE"),
bigquery.SchemaField("invocation_id", "STRING", mode="NULLABLE"),
bigquery.SchemaField("user_id", "STRING", mode="NULLABLE"),
bigquery.SchemaField("content", "STRING", mode="NULLABLE"),
bigquery.SchemaField("error_message", "STRING", mode="NULLABLE"),
]
def _format_content_safely(
@@ -334,7 +341,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
try:
if self._config.content_formatter:
return self._config.content_formatter(content)
return _format_content(content)
return _format_content(content, max_len=self._config.max_content_length)
except Exception as e:
logging.warning(f"Content formatter failed: {e}")
return "[FORMATTING FAILED]"
@@ -363,14 +370,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
# Ensure table exists (sync call in thread)
def create_resources():
if self._bq_client:
dataset = self._bq_client.create_dataset(
self._dataset_id, exists_ok=True
)
self._bq_client.create_dataset(self._dataset_id, exists_ok=True)
table = bigquery.Table(
f"{self._project_id}.{self._dataset_id}.{self._table_id}",
schema=self._schema,
)
self._bq_client.create_table(table, exists_ok=True)
logging.info(
"BQ Plugin: Dataset %s and Table %s ensured to exist.",
self._dataset_id,
self._table_id,
)
await asyncio.to_thread(create_resources)
@@ -379,9 +389,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
client_info=client_info,
)
self._arrow_schema = to_arrow_schema(self._schema)
if not self._arrow_schema:
raise RuntimeError("Failed to convert BigQuery schema to Arrow.")
logging.info("BQ Plugin: Initialized successfully.")
return True
except Exception as e:
logging.error(f"BQ Init Failed: {e}")
logging.error("BQ Plugin: Init Failed:", exc_info=True)
return False
async def _perform_write(self, row: dict):
@@ -412,14 +425,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
self._write_client.append_rows(iter([req]))
):
if resp.error.code != 0:
logging.error(f"BQ Write Error: {resp.error.message}")
logging.error(f"BQ Plugin: Write Error: {resp.error.message}")
except RuntimeError as e:
# Silently ignore event loop closed errors during background writes
if "Event loop is closed" not in str(e):
logging.exception(f"BQ Runtime Error: {e}")
if "Event loop is closed" not in str(e) and not self._is_shutting_down:
logging.error("BQ Plugin: Runtime Error during write:", exc_info=True)
except asyncio.CancelledError:
if not self._is_shutting_down:
logging.warning("BQ Plugin: Write task cancelled unexpectedly.")
except Exception as e:
logging.error(f"BQ Write Failed: {e}")
logging.error("BQ Plugin: Write Failed:", exc_info=True)
async def _log(self, data: dict):
"""Schedules a log entry to be written in the background."""
@@ -457,32 +472,44 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def close(self):
"""Flushes pending logs and closes client."""
# 1. Wait for pending background logs (best effort, 2s timeout)
if self._background_tasks:
logging.info(f"Flushing {len(self._background_tasks)} pending BQ logs...")
done, pending = await asyncio.wait(self._background_tasks, timeout=2.0)
if pending:
logging.warning(
f"{len(pending)} BQ logs could not be flushed before shutdown."
)
if self._is_shutting_down:
return
self._is_shutting_down = True
logging.info("BQ Plugin: Shutdown started.")
# 2. Close client
if self._write_client and self._write_client.transport:
if self._background_tasks:
logging.info(
f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..."
)
try:
logging.info("Closing BQ Write client transport...")
await asyncio.wait(
self._background_tasks, timeout=self._config.shutdown_timeout
)
except asyncio.TimeoutError:
logging.warning("BQ Plugin: Timeout waiting for logs to flush.")
except Exception as e:
logging.warning("BQ Plugin: Error flushing logs:", exc_info=True)
# Use getattr for safe access in case transport is not present.
if self._write_client and getattr(self._write_client, "transport", None):
try:
logging.info("BQ Plugin: Closing write client.")
await asyncio.wait_for(
self._write_client.transport.close(), timeout=1.0
self._write_client.transport.close(),
timeout=self._config.client_close_timeout,
)
except Exception as e:
logging.warning(f"Error during BQ Write client transport close: {e}")
self._write_client = None
logging.warning(f"BQ Plugin: Error closing write client: {e}")
if self._bq_client:
try:
logging.info("Closing BQ client...")
self._bq_client.close()
except Exception as e:
logging.warning(f"Error during BQ client close: {e}")
self._bq_client = None
logging.warning(f"BQ Plugin: Error closing BQ client: {e}")
self._write_client = None
self._bq_client = None
self._is_shutting_down = False
logging.info("BQ Plugin: Shutdown complete.")
# --- Streamlined Callbacks ---
async def on_user_message_callback(
@@ -523,13 +550,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
"session_id": invocation_context.session.id,
"invocation_id": invocation_context.invocation_id,
"user_id": invocation_context.session.user_id,
"content": (
json.dumps(
[part.model_dump(mode="json") for part in event.content.parts]
)
if event.content and event.content.parts
else None
),
"content": self._format_content_safely(event.content),
"error_message": event.error_message,
"timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc),
})
@@ -579,6 +600,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
content_parts = [
f"Model: {llm_request.model or 'default'}",
]
if contents := getattr(llm_request, "contents", None):
prompt_str = " | ".join(
[f"{c.role}: {self._format_content_safely(c)}" for c in contents]
)
content_parts.append(f"Prompt: {prompt_str}")
system_instruction_text = "None"
if llm_request.config and llm_request.config.system_instruction:
si = llm_request.config.system_instruction
@@ -627,6 +653,9 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
)
final_content = " | ".join(content_parts)
max_len = self._config.max_content_length
if len(final_content) > max_len:
final_content = final_content[:max_len] + "..."
await self._log({
"event_type": "LLM_REQUEST",
"agent": callback_context.agent_name,
@@ -702,7 +731,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
"user_id": tool_context.session.user_id,
"content": (
f"Tool Name: {tool.name}, Description: {tool.description},"
f" Arguments: {_format_args(tool_args)}"
" Arguments:"
f" {_format_args(tool_args, max_len=self._config.max_content_length)}"
),
})
@@ -721,7 +751,10 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
"session_id": tool_context.session.id,
"invocation_id": tool_context.invocation_id,
"user_id": tool_context.session.user_id,
"content": f"Tool Name: {tool.name}, Result: {_format_args(result)}",
"content": (
f"Tool Name: {tool.name}, Result:"
f" {_format_args(result, max_len=self._config.max_content_length)}"
),
})
async def on_model_error_callback(
@@ -757,7 +790,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
"invocation_id": tool_context.invocation_id,
"user_id": tool_context.session.user_id,
"content": (
f"Tool Name: {tool.name}, Arguments: {_format_args(tool_args)}"
f"Tool Name: {tool.name}, Arguments:"
f" {_format_args(tool_args, max_len=self._config.max_content_length)}"
),
"error_message": str(error),
})
@@ -146,14 +146,14 @@ def mock_write_client():
@pytest.fixture
def dummy_arrow_schema():
return pa.schema([
pa.field("timestamp", pa.timestamp("us", tz="UTC")),
pa.field("event_type", pa.string()),
pa.field("agent", pa.string()),
pa.field("session_id", pa.string()),
pa.field("invocation_id", pa.string()),
pa.field("user_id", pa.string()),
pa.field("content", pa.string()),
pa.field("error_message", pa.string()),
pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False),
pa.field("event_type", pa.string(), nullable=True),
pa.field("agent", pa.string(), nullable=True),
pa.field("session_id", pa.string(), nullable=True),
pa.field("invocation_id", pa.string(), nullable=True),
pa.field("user_id", pa.string(), nullable=True),
pa.field("content", pa.string(), nullable=True),
pa.field("error_message", pa.string(), nullable=True),
])
@@ -391,6 +391,102 @@ class TestBigQueryAgentAnalyticsPlugin:
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
assert log_entry["content"] == "User Content: [FORMATTING FAILED]"
@pytest.mark.asyncio
async def test_max_content_length(
self,
mock_write_client,
invocation_context,
callback_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
config = BigQueryLoggerConfig(max_content_length=40)
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, TABLE_ID, config
)
await plugin._ensure_init()
mock_write_client.append_rows.reset_mock()
# Test User Message Truncation
user_message = types.Content(
parts=[types.Part(text="12345678901234567890123456789012345678901")]
) # 41 chars
await plugin.on_user_message_callback(
invocation_context=invocation_context, user_message=user_message
)
await asyncio.sleep(0.01)
mock_write_client.append_rows.assert_called_once()
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
assert (
log_entry["content"]
== "User Content: text: '1234567890123456789012345678901234567890...' "
)
mock_write_client.append_rows.reset_mock()
# Test before_model_callback full content truncation
llm_request = llm_request_lib.LlmRequest(
model="gemini-pro",
config=types.GenerateContentConfig(
system_instruction=types.Content(
parts=[types.Part(text="System Instruction")]
)
),
contents=[
types.Content(role="user", parts=[types.Part(text="Prompt")])
],
)
await plugin.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
await asyncio.sleep(0.01)
mock_write_client.append_rows.assert_called_once()
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
# Full content: "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt: System Instruction"
# Truncated to 40 chars + ...:
expected_content = "Model: gemini-pro | Prompt: user: text: ..."
assert log_entry["content"] == expected_content
@pytest.mark.asyncio
async def test_max_content_length_tool_args(
self,
mock_write_client,
tool_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
config = BigQueryLoggerConfig(max_content_length=10)
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, TABLE_ID, config
)
await plugin._ensure_init()
mock_write_client.append_rows.reset_mock()
mock_tool = mock.create_autospec(
base_tool_lib.BaseTool, instance=True, spec_set=True
)
type(mock_tool).name = mock.PropertyMock(return_value="MyTool")
type(mock_tool).description = mock.PropertyMock(return_value="Description")
# Args length > 10
# {"param": "long_value"} is ~24 chars
await plugin.before_tool_callback(
tool=mock_tool,
tool_args={"param": "long_value"},
tool_context=tool_context,
)
await asyncio.sleep(0.01)
mock_write_client.append_rows.assert_called_once()
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
# JSON string: '{"param": "long_value"}'
# Truncated to 10: '{"param": ...'
assert 'Arguments: {"param": ...' in log_entry["content"]
@pytest.mark.asyncio
async def test_on_user_message_callback_logs_correctly(
self,
@@ -430,7 +526,7 @@ class TestBigQueryAgentAnalyticsPlugin:
await asyncio.sleep(0.01)
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
_assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent")
assert '"name": "get_weather"' in log_entry["content"]
assert "call: get_weather" in log_entry["content"]
assert log_entry["timestamp"] == datetime.datetime(
2025, 10, 22, 10, 0, 0, tzinfo=datetime.timezone.utc
)
@@ -456,7 +552,7 @@ class TestBigQueryAgentAnalyticsPlugin:
await asyncio.sleep(0.01)
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
_assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent")
assert '"text": "Hello there!"' in log_entry["content"]
assert "text: 'Hello there!'" in log_entry["content"]
assert log_entry["timestamp"] == datetime.datetime(
2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc
)
@@ -485,7 +581,7 @@ class TestBigQueryAgentAnalyticsPlugin:
user_message=types.Content(parts=[types.Part(text="Test")]),
)
await asyncio.sleep(0.01)
mock_log_error.assert_any_call("BQ Init Failed: Auth failed")
mock_log_error.assert_any_call("BQ Plugin: Init Failed:", exc_info=True)
mock_write_client.append_rows.assert_not_called()
@pytest.mark.asyncio
@@ -509,16 +605,20 @@ class TestBigQueryAgentAnalyticsPlugin:
user_message=types.Content(parts=[types.Part(text="Test")]),
)
await asyncio.sleep(0.01)
mock_log_error.assert_called_with("BQ Write Error: Test BQ Error")
mock_log_error.assert_called_with("BQ Plugin: Write Error: Test BQ Error")
mock_write_client.append_rows.assert_called_once()
@pytest.mark.asyncio
async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client):
await bq_plugin_inst.close()
mock_write_client.transport.close.assert_called_once()
mock_bq_client.close.assert_called_once()
# bq_client might not be closed if it wasn't created or if close() failed,
# but here it should be.
# in the new implementation we verify attributes are reset
assert bq_plugin_inst._write_client is None
assert bq_plugin_inst._bq_client is None
assert bq_plugin_inst._is_shutting_down is False
# ... other tests remain the same ...
@pytest.mark.asyncio
async def test_before_run_callback_logs_correctly(
self,
@@ -595,7 +695,9 @@ class TestBigQueryAgentAnalyticsPlugin:
):
llm_request = llm_request_lib.LlmRequest(
model="gemini-pro",
contents=[types.Content(parts=[types.Part(text="Prompt")])],
contents=[
types.Content(role="user", parts=[types.Part(text="Prompt")])
],
)
await bq_plugin_inst.before_model_callback(
callback_context=callback_context, llm_request=llm_request
@@ -603,7 +705,11 @@ class TestBigQueryAgentAnalyticsPlugin:
await asyncio.sleep(0.01)
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
_assert_common_fields(log_entry, "LLM_REQUEST")
assert log_entry["content"] == "Model: gemini-pro | System Prompt: Empty"
assert (
log_entry["content"]
== "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt:"
" Empty"
)
@pytest.mark.asyncio
async def test_after_model_callback_text_response(