feat: Enhance BQ Plugin Schema, Error Handling, and Logging

This update enhances the BigQuery agent analytics plugin:

*   **Enhanced Error Logging:** Improved error messages for schema mismatches.
*   **Reordered Logging Content:** Prioritized metadata in `before_model_callback`.

PiperOrigin-RevId: 833508755
This commit is contained in:
Google Team Member
2025-11-17 15:00:09 -08:00
committed by Copybara-Service
parent c642f13f21
commit 5ac5129fb0
2 changed files with 221 additions and 41 deletions
@@ -437,7 +437,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
return self._config.content_formatter(content), False
return _format_content(content, max_len=self._config.max_content_length)
except Exception as e:
logging.warning(f"Content formatter failed: {e}")
logging.warning("Content formatter failed: %s", e)
return "[FORMATTING FAILED]", False
async def _ensure_init(self):
@@ -523,7 +523,21 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
self._write_client.append_rows(iter([req]))
):
if resp.error.code != 0:
logging.error(f"BQ Plugin: Write Error: {resp.error.message}")
msg = resp.error.message
# Check for common schema mismatch indicators
if (
"schema mismatch" in msg.lower()
or "field" in msg.lower()
or "type" in msg.lower()
):
logging.error(
"BQ Plugin: Schema Mismatch Error. The BigQuery table schema"
" may be incorrect or out of sync with the plugin. Please"
" verify the table definition. Details: %s",
msg,
)
else:
logging.error("BQ Plugin: Write Error: %s", msg)
except RuntimeError as e:
if "Event loop is closed" not in str(e) and not self._is_shutting_down:
@@ -578,7 +592,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
if self._background_tasks:
logging.info(
f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..."
"BQ Plugin: Flushing %s pending logs...", len(self._background_tasks)
)
try:
await asyncio.wait(
@@ -598,12 +612,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
timeout=self._config.client_close_timeout,
)
except Exception as e:
logging.warning(f"BQ Plugin: Error closing write client: {e}")
logging.warning("BQ Plugin: Error closing write client: %s", e)
if self._bq_client:
try:
self._bq_client.close()
except Exception as e:
logging.warning(f"BQ Plugin: Error closing BQ client: {e}")
logging.warning("BQ Plugin: Error closing BQ client: %s", e)
self._write_client = None
self._bq_client = None
@@ -617,7 +631,14 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
invocation_context: InvocationContext,
user_message: types.Content,
) -> None:
"""Callback for user messages."""
"""Callback for user messages.
Logs the user message details including:
1. User content (text)
The content is formatted as 'User Content: {content}'.
If the content length exceeds `max_content_length`, it is truncated.
"""
content, truncated = self._format_content_safely(user_message)
await self._log({
"event_type": "USER_MESSAGE_RECEIVED",
@@ -632,7 +653,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def before_run_callback(
self, *, invocation_context: InvocationContext
) -> None:
"""Callback before agent invocation."""
"""Callback before agent invocation.
Logs the start of an agent invocation.
No specific content payload is logged for this event, but standard metadata
(agent name, session ID, invocation ID, user ID) is captured.
"""
await self._log({
"event_type": "INVOCATION_STARTING",
"agent": invocation_context.agent.name,
@@ -644,7 +670,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> None:
"""Callback for agent events."""
"""Callback for agent events.
Logs generic agent events including:
1. Event type (determined from event properties)
2. Event content (text, function calls, or responses)
3. Error messages (if any)
The content is formatted based on the event type.
If the content length exceeds `max_content_length`, it is truncated.
"""
content, truncated = self._format_content_safely(event.content)
await self._log({
"event_type": _get_event_type(event),
@@ -661,7 +696,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def after_run_callback(
self, *, invocation_context: InvocationContext
) -> None:
"""Callback after agent invocation."""
"""Callback after agent invocation.
Logs the completion of an agent invocation.
No specific content payload is logged for this event, but standard metadata
(agent name, session ID, invocation ID, user ID) is captured.
"""
await self._log({
"event_type": "INVOCATION_COMPLETED",
"agent": invocation_context.agent.name,
@@ -673,7 +713,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def before_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> None:
"""Callback before an agent starts."""
"""Callback before an agent starts.
Logs the start of a specific agent execution.
Content includes:
1. Agent Name (from callback context)
"""
await self._log({
"event_type": "AGENT_STARTING",
"agent": agent.name,
@@ -686,7 +731,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def after_agent_callback(
self, *, agent: BaseAgent, callback_context: CallbackContext
) -> None:
"""Callback after an agent completes."""
"""Callback after an agent completes.
Logs the completion of a specific agent execution.
Content includes:
1. Agent Name (from callback context)
"""
await self._log({
"event_type": "AGENT_COMPLETED",
"agent": agent.name,
@@ -699,11 +749,52 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> None:
"""Callback before LLM call."""
"""Callback before LLM call.
Logs the LLM request details including:
1. Model name
2. Configuration parameters (temperature, top_p, top_k, max_output_tokens)
3. Available tool names
4. Prompt content (user/model messages)
5. System instructions
The content is formatted as a single string with fields separated by ' | '.
If the total length exceeds `max_content_length`, the string is truncated,
prioritizing the metadata (Model, Params, Tools) over the Prompt and System
Prompt.
"""
content_parts = [
f"Model: {llm_request.model or 'default'}",
]
is_truncated = False
# 1. Params
if llm_request.config:
config = llm_request.config
params_to_log = {}
if hasattr(config, "temperature") and config.temperature is not None:
params_to_log["temperature"] = config.temperature
if hasattr(config, "top_p") and config.top_p is not None:
params_to_log["top_p"] = config.top_p
if hasattr(config, "top_k") and config.top_k is not None:
params_to_log["top_k"] = config.top_k
if (
hasattr(config, "max_output_tokens")
and config.max_output_tokens is not None
):
params_to_log["max_output_tokens"] = config.max_output_tokens
if params_to_log:
params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()])
content_parts.append(f"Params: {{{params_str}}}")
# 2. Tools
if llm_request.tools_dict:
content_parts.append(
f"Available Tools: {list(llm_request.tools_dict.keys())}"
)
# 3. Prompt
if contents := getattr(llm_request, "contents", None):
prompt_parts = []
for c in contents:
@@ -713,6 +804,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
is_truncated = True
prompt_str = " | ".join(prompt_parts)
content_parts.append(f"Prompt: {prompt_str}")
# 4. System Prompt
system_instruction_text = "None"
if llm_request.config and llm_request.config.system_instruction:
si = llm_request.config.system_instruction
@@ -736,29 +829,6 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
system_instruction_text = "Empty"
content_parts.append(f"System Prompt: {system_instruction_text}")
if llm_request.config:
config = llm_request.config
params_to_log = {}
if hasattr(config, "temperature") and config.temperature is not None:
params_to_log["temperature"] = config.temperature
if hasattr(config, "top_p") and config.top_p is not None:
params_to_log["top_p"] = config.top_p
if hasattr(config, "top_k") and config.top_k is not None:
params_to_log["top_k"] = config.top_k
if (
hasattr(config, "max_output_tokens")
and config.max_output_tokens is not None
):
params_to_log["max_output_tokens"] = config.max_output_tokens
if params_to_log:
params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()])
content_parts.append(f"Params: {{{params_str}}}")
if llm_request.tools_dict:
content_parts.append(
f"Available Tools: {list(llm_request.tools_dict.keys())}"
)
final_content = " | ".join(content_parts)
max_len = self._config.max_content_length
@@ -778,7 +848,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def after_model_callback(
self, *, callback_context: CallbackContext, llm_response: LlmResponse
) -> None:
"""Callback after LLM call."""
"""Callback after LLM call.
Logs the LLM response details including:
1. Tool calls (if any)
2. Text response (if no tool calls)
3. Token usage statistics (prompt, candidates, total)
The content is formatted as a single string with fields separated by ' | '.
If the content length exceeds `max_content_length`, it is truncated.
"""
content_parts = []
content = llm_response.content
is_tool_call = False
@@ -838,7 +917,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
tool_args: dict[str, Any],
tool_context: ToolContext,
) -> None:
"""Callback before tool call."""
"""Callback before tool call.
Logs the tool execution start details including:
1. Tool name
2. Tool description
3. Tool arguments
The content is formatted as 'Tool Name: ..., Description: ..., Arguments:
...'.
If the content length exceeds `max_content_length`, it is truncated.
"""
args_str, truncated = _format_args(
tool_args, max_len=self._config.max_content_length
)
@@ -867,7 +956,15 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
tool_context: ToolContext,
result: dict[str, Any],
) -> None:
"""Callback after tool call."""
"""Callback after tool call.
Logs the tool execution result details including:
1. Tool name
2. Tool result
The content is formatted as 'Tool Name: ..., Result: ...'.
If the content length exceeds `max_content_length`, it is truncated.
"""
result_str, truncated = _format_args(
result, max_len=self._config.max_content_length
)
@@ -892,7 +989,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
llm_request: LlmRequest,
error: Exception,
) -> None:
"""Callback for LLM errors."""
"""Callback for model errors.
Logs errors that occur during LLM calls.
No specific content payload is logged, but the error message is captured
in the `error_message` field.
"""
await self._log({
"event_type": "LLM_ERROR",
"agent": callback_context.agent_name,
@@ -910,7 +1012,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
tool_context: ToolContext,
error: Exception,
) -> None:
"""Callback for tool errors."""
"""Callback for tool errors.
Logs errors that occur during tool execution.
Content includes:
1. Tool name
2. Tool arguments
The error message is captured in the `error_message` field.
If the content length exceeds `max_content_length`, it is truncated.
"""
args_str, truncated = _format_args(
tool_args, max_len=self._config.max_content_length
)
@@ -695,9 +695,42 @@ class TestBigQueryAgentAnalyticsPlugin:
user_message=types.Content(parts=[types.Part(text="Test")]),
)
await asyncio.sleep(0.01)
mock_log_error.assert_called_with("BQ Plugin: Write Error: Test BQ Error")
mock_log_error.assert_called_with(
"BQ Plugin: Write Error: %s", "Test BQ Error"
)
mock_write_client.append_rows.assert_called_once()
@pytest.mark.asyncio
async def test_schema_mismatch_error_handling(
self, bq_plugin_inst, mock_write_client, invocation_context
):
async def fake_append_rows_with_schema_error(requests, **kwargs):
mock_resp = mock.MagicMock()
mock_resp.row_errors = []
mock_resp.error = mock.MagicMock()
mock_resp.error.code = 3
mock_resp.error.message = (
"Schema mismatch: Field 'new_field' not found in table."
)
return _async_gen(mock_resp)
mock_write_client.append_rows.side_effect = (
fake_append_rows_with_schema_error
)
with mock.patch.object(logging, "error") as mock_log_error:
await bq_plugin_inst.on_user_message_callback(
invocation_context=invocation_context,
user_message=types.Content(parts=[types.Part(text="Test")]),
)
await asyncio.sleep(0.01)
mock_log_error.assert_called_with(
"BQ Plugin: Schema Mismatch Error. The BigQuery table schema may be"
" incorrect or out of sync with the plugin. Please verify the table"
" definition. Details: %s",
"Schema mismatch: Field 'new_field' not found in table.",
)
@pytest.mark.asyncio
async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client):
await bq_plugin_inst.close()
@@ -801,6 +834,42 @@ class TestBigQueryAgentAnalyticsPlugin:
" Empty"
)
@pytest.mark.asyncio
async def test_before_model_callback_with_params_and_tools(
self,
bq_plugin_inst,
mock_write_client,
callback_context,
dummy_arrow_schema,
):
llm_request = llm_request_lib.LlmRequest(
model="gemini-pro",
config=types.GenerateContentConfig(
temperature=0.5,
top_p=0.9,
system_instruction=types.Content(parts=[types.Part(text="Sys")]),
),
contents=[types.Content(role="user", parts=[types.Part(text="User")])],
)
# Manually set tools_dict as it is excluded from init
llm_request.tools_dict = {"tool1": "func1", "tool2": "func2"}
await bq_plugin_inst.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
await asyncio.sleep(0.01)
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
_assert_common_fields(log_entry, "LLM_REQUEST")
# Order: Model | Params | Tools | Prompt | System Prompt
# Note: Params order depends on dict iteration but here we construct it deterministically in code?
# The code does: params_to_log["temperature"] = ... then "top_p" = ...
# So order should be temperature, top_p.
assert "Model: gemini-pro" in log_entry["content"]
assert "Params: {temperature=0.5, top_p=0.9}" in log_entry["content"]
assert "Available Tools: ['tool1', 'tool2']" in log_entry["content"]
assert "Prompt: user: text: 'User'" in log_entry["content"]
assert "System Prompt: Sys" in log_entry["content"]
@pytest.mark.asyncio
async def test_after_model_callback_text_response(
self,