You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Enhance BQ Plugin Schema, Error Handling, and Logging
This update enhances the BigQuery agent analytics plugin: * **Enhanced Error Logging:** Improved error messages for schema mismatches. * **Reordered Logging Content:** Prioritized metadata in `before_model_callback`. PiperOrigin-RevId: 833508755
This commit is contained in:
committed by
Copybara-Service
parent
c642f13f21
commit
5ac5129fb0
@@ -437,7 +437,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
return self._config.content_formatter(content), False
|
||||
return _format_content(content, max_len=self._config.max_content_length)
|
||||
except Exception as e:
|
||||
logging.warning(f"Content formatter failed: {e}")
|
||||
logging.warning("Content formatter failed: %s", e)
|
||||
return "[FORMATTING FAILED]", False
|
||||
|
||||
async def _ensure_init(self):
|
||||
@@ -523,7 +523,21 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
self._write_client.append_rows(iter([req]))
|
||||
):
|
||||
if resp.error.code != 0:
|
||||
logging.error(f"BQ Plugin: Write Error: {resp.error.message}")
|
||||
msg = resp.error.message
|
||||
# Check for common schema mismatch indicators
|
||||
if (
|
||||
"schema mismatch" in msg.lower()
|
||||
or "field" in msg.lower()
|
||||
or "type" in msg.lower()
|
||||
):
|
||||
logging.error(
|
||||
"BQ Plugin: Schema Mismatch Error. The BigQuery table schema"
|
||||
" may be incorrect or out of sync with the plugin. Please"
|
||||
" verify the table definition. Details: %s",
|
||||
msg,
|
||||
)
|
||||
else:
|
||||
logging.error("BQ Plugin: Write Error: %s", msg)
|
||||
|
||||
except RuntimeError as e:
|
||||
if "Event loop is closed" not in str(e) and not self._is_shutting_down:
|
||||
@@ -578,7 +592,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
|
||||
if self._background_tasks:
|
||||
logging.info(
|
||||
f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..."
|
||||
"BQ Plugin: Flushing %s pending logs...", len(self._background_tasks)
|
||||
)
|
||||
try:
|
||||
await asyncio.wait(
|
||||
@@ -598,12 +612,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
timeout=self._config.client_close_timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"BQ Plugin: Error closing write client: {e}")
|
||||
logging.warning("BQ Plugin: Error closing write client: %s", e)
|
||||
if self._bq_client:
|
||||
try:
|
||||
self._bq_client.close()
|
||||
except Exception as e:
|
||||
logging.warning(f"BQ Plugin: Error closing BQ client: {e}")
|
||||
logging.warning("BQ Plugin: Error closing BQ client: %s", e)
|
||||
|
||||
self._write_client = None
|
||||
self._bq_client = None
|
||||
@@ -617,7 +631,14 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
invocation_context: InvocationContext,
|
||||
user_message: types.Content,
|
||||
) -> None:
|
||||
"""Callback for user messages."""
|
||||
"""Callback for user messages.
|
||||
|
||||
Logs the user message details including:
|
||||
1. User content (text)
|
||||
|
||||
The content is formatted as 'User Content: {content}'.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
content, truncated = self._format_content_safely(user_message)
|
||||
await self._log({
|
||||
"event_type": "USER_MESSAGE_RECEIVED",
|
||||
@@ -632,7 +653,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def before_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> None:
|
||||
"""Callback before agent invocation."""
|
||||
"""Callback before agent invocation.
|
||||
|
||||
Logs the start of an agent invocation.
|
||||
No specific content payload is logged for this event, but standard metadata
|
||||
(agent name, session ID, invocation ID, user ID) is captured.
|
||||
"""
|
||||
await self._log({
|
||||
"event_type": "INVOCATION_STARTING",
|
||||
"agent": invocation_context.agent.name,
|
||||
@@ -644,7 +670,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def on_event_callback(
|
||||
self, *, invocation_context: InvocationContext, event: Event
|
||||
) -> None:
|
||||
"""Callback for agent events."""
|
||||
"""Callback for agent events.
|
||||
|
||||
Logs generic agent events including:
|
||||
1. Event type (determined from event properties)
|
||||
2. Event content (text, function calls, or responses)
|
||||
3. Error messages (if any)
|
||||
|
||||
The content is formatted based on the event type.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
content, truncated = self._format_content_safely(event.content)
|
||||
await self._log({
|
||||
"event_type": _get_event_type(event),
|
||||
@@ -661,7 +696,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def after_run_callback(
|
||||
self, *, invocation_context: InvocationContext
|
||||
) -> None:
|
||||
"""Callback after agent invocation."""
|
||||
"""Callback after agent invocation.
|
||||
|
||||
Logs the completion of an agent invocation.
|
||||
No specific content payload is logged for this event, but standard metadata
|
||||
(agent name, session ID, invocation ID, user ID) is captured.
|
||||
"""
|
||||
await self._log({
|
||||
"event_type": "INVOCATION_COMPLETED",
|
||||
"agent": invocation_context.agent.name,
|
||||
@@ -673,7 +713,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def before_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
) -> None:
|
||||
"""Callback before an agent starts."""
|
||||
"""Callback before an agent starts.
|
||||
|
||||
Logs the start of a specific agent execution.
|
||||
Content includes:
|
||||
1. Agent Name (from callback context)
|
||||
"""
|
||||
await self._log({
|
||||
"event_type": "AGENT_STARTING",
|
||||
"agent": agent.name,
|
||||
@@ -686,7 +731,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def after_agent_callback(
|
||||
self, *, agent: BaseAgent, callback_context: CallbackContext
|
||||
) -> None:
|
||||
"""Callback after an agent completes."""
|
||||
"""Callback after an agent completes.
|
||||
|
||||
Logs the completion of a specific agent execution.
|
||||
Content includes:
|
||||
1. Agent Name (from callback context)
|
||||
"""
|
||||
await self._log({
|
||||
"event_type": "AGENT_COMPLETED",
|
||||
"agent": agent.name,
|
||||
@@ -699,11 +749,52 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def before_model_callback(
|
||||
self, *, callback_context: CallbackContext, llm_request: LlmRequest
|
||||
) -> None:
|
||||
"""Callback before LLM call."""
|
||||
"""Callback before LLM call.
|
||||
|
||||
Logs the LLM request details including:
|
||||
1. Model name
|
||||
2. Configuration parameters (temperature, top_p, top_k, max_output_tokens)
|
||||
3. Available tool names
|
||||
4. Prompt content (user/model messages)
|
||||
5. System instructions
|
||||
|
||||
The content is formatted as a single string with fields separated by ' | '.
|
||||
If the total length exceeds `max_content_length`, the string is truncated,
|
||||
prioritizing the metadata (Model, Params, Tools) over the Prompt and System
|
||||
Prompt.
|
||||
"""
|
||||
content_parts = [
|
||||
f"Model: {llm_request.model or 'default'}",
|
||||
]
|
||||
is_truncated = False
|
||||
|
||||
# 1. Params
|
||||
if llm_request.config:
|
||||
config = llm_request.config
|
||||
params_to_log = {}
|
||||
if hasattr(config, "temperature") and config.temperature is not None:
|
||||
params_to_log["temperature"] = config.temperature
|
||||
if hasattr(config, "top_p") and config.top_p is not None:
|
||||
params_to_log["top_p"] = config.top_p
|
||||
if hasattr(config, "top_k") and config.top_k is not None:
|
||||
params_to_log["top_k"] = config.top_k
|
||||
if (
|
||||
hasattr(config, "max_output_tokens")
|
||||
and config.max_output_tokens is not None
|
||||
):
|
||||
params_to_log["max_output_tokens"] = config.max_output_tokens
|
||||
|
||||
if params_to_log:
|
||||
params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()])
|
||||
content_parts.append(f"Params: {{{params_str}}}")
|
||||
|
||||
# 2. Tools
|
||||
if llm_request.tools_dict:
|
||||
content_parts.append(
|
||||
f"Available Tools: {list(llm_request.tools_dict.keys())}"
|
||||
)
|
||||
|
||||
# 3. Prompt
|
||||
if contents := getattr(llm_request, "contents", None):
|
||||
prompt_parts = []
|
||||
for c in contents:
|
||||
@@ -713,6 +804,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
is_truncated = True
|
||||
prompt_str = " | ".join(prompt_parts)
|
||||
content_parts.append(f"Prompt: {prompt_str}")
|
||||
|
||||
# 4. System Prompt
|
||||
system_instruction_text = "None"
|
||||
if llm_request.config and llm_request.config.system_instruction:
|
||||
si = llm_request.config.system_instruction
|
||||
@@ -736,29 +829,6 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
system_instruction_text = "Empty"
|
||||
|
||||
content_parts.append(f"System Prompt: {system_instruction_text}")
|
||||
if llm_request.config:
|
||||
config = llm_request.config
|
||||
params_to_log = {}
|
||||
if hasattr(config, "temperature") and config.temperature is not None:
|
||||
params_to_log["temperature"] = config.temperature
|
||||
if hasattr(config, "top_p") and config.top_p is not None:
|
||||
params_to_log["top_p"] = config.top_p
|
||||
if hasattr(config, "top_k") and config.top_k is not None:
|
||||
params_to_log["top_k"] = config.top_k
|
||||
if (
|
||||
hasattr(config, "max_output_tokens")
|
||||
and config.max_output_tokens is not None
|
||||
):
|
||||
params_to_log["max_output_tokens"] = config.max_output_tokens
|
||||
|
||||
if params_to_log:
|
||||
params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()])
|
||||
content_parts.append(f"Params: {{{params_str}}}")
|
||||
|
||||
if llm_request.tools_dict:
|
||||
content_parts.append(
|
||||
f"Available Tools: {list(llm_request.tools_dict.keys())}"
|
||||
)
|
||||
|
||||
final_content = " | ".join(content_parts)
|
||||
max_len = self._config.max_content_length
|
||||
@@ -778,7 +848,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
async def after_model_callback(
|
||||
self, *, callback_context: CallbackContext, llm_response: LlmResponse
|
||||
) -> None:
|
||||
"""Callback after LLM call."""
|
||||
"""Callback after LLM call.
|
||||
|
||||
Logs the LLM response details including:
|
||||
1. Tool calls (if any)
|
||||
2. Text response (if no tool calls)
|
||||
3. Token usage statistics (prompt, candidates, total)
|
||||
|
||||
The content is formatted as a single string with fields separated by ' | '.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
content_parts = []
|
||||
content = llm_response.content
|
||||
is_tool_call = False
|
||||
@@ -838,7 +917,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
tool_args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> None:
|
||||
"""Callback before tool call."""
|
||||
"""Callback before tool call.
|
||||
|
||||
Logs the tool execution start details including:
|
||||
1. Tool name
|
||||
2. Tool description
|
||||
3. Tool arguments
|
||||
|
||||
The content is formatted as 'Tool Name: ..., Description: ..., Arguments:
|
||||
...'.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
args_str, truncated = _format_args(
|
||||
tool_args, max_len=self._config.max_content_length
|
||||
)
|
||||
@@ -867,7 +956,15 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
tool_context: ToolContext,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Callback after tool call."""
|
||||
"""Callback after tool call.
|
||||
|
||||
Logs the tool execution result details including:
|
||||
1. Tool name
|
||||
2. Tool result
|
||||
|
||||
The content is formatted as 'Tool Name: ..., Result: ...'.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
result_str, truncated = _format_args(
|
||||
result, max_len=self._config.max_content_length
|
||||
)
|
||||
@@ -892,7 +989,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
llm_request: LlmRequest,
|
||||
error: Exception,
|
||||
) -> None:
|
||||
"""Callback for LLM errors."""
|
||||
"""Callback for model errors.
|
||||
|
||||
Logs errors that occur during LLM calls.
|
||||
No specific content payload is logged, but the error message is captured
|
||||
in the `error_message` field.
|
||||
"""
|
||||
await self._log({
|
||||
"event_type": "LLM_ERROR",
|
||||
"agent": callback_context.agent_name,
|
||||
@@ -910,7 +1012,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
tool_context: ToolContext,
|
||||
error: Exception,
|
||||
) -> None:
|
||||
"""Callback for tool errors."""
|
||||
"""Callback for tool errors.
|
||||
|
||||
Logs errors that occur during tool execution.
|
||||
Content includes:
|
||||
1. Tool name
|
||||
2. Tool arguments
|
||||
|
||||
The error message is captured in the `error_message` field.
|
||||
If the content length exceeds `max_content_length`, it is truncated.
|
||||
"""
|
||||
args_str, truncated = _format_args(
|
||||
tool_args, max_len=self._config.max_content_length
|
||||
)
|
||||
|
||||
@@ -695,9 +695,42 @@ class TestBigQueryAgentAnalyticsPlugin:
|
||||
user_message=types.Content(parts=[types.Part(text="Test")]),
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
mock_log_error.assert_called_with("BQ Plugin: Write Error: Test BQ Error")
|
||||
mock_log_error.assert_called_with(
|
||||
"BQ Plugin: Write Error: %s", "Test BQ Error"
|
||||
)
|
||||
mock_write_client.append_rows.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_mismatch_error_handling(
|
||||
self, bq_plugin_inst, mock_write_client, invocation_context
|
||||
):
|
||||
async def fake_append_rows_with_schema_error(requests, **kwargs):
|
||||
mock_resp = mock.MagicMock()
|
||||
mock_resp.row_errors = []
|
||||
mock_resp.error = mock.MagicMock()
|
||||
mock_resp.error.code = 3
|
||||
mock_resp.error.message = (
|
||||
"Schema mismatch: Field 'new_field' not found in table."
|
||||
)
|
||||
return _async_gen(mock_resp)
|
||||
|
||||
mock_write_client.append_rows.side_effect = (
|
||||
fake_append_rows_with_schema_error
|
||||
)
|
||||
|
||||
with mock.patch.object(logging, "error") as mock_log_error:
|
||||
await bq_plugin_inst.on_user_message_callback(
|
||||
invocation_context=invocation_context,
|
||||
user_message=types.Content(parts=[types.Part(text="Test")]),
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
mock_log_error.assert_called_with(
|
||||
"BQ Plugin: Schema Mismatch Error. The BigQuery table schema may be"
|
||||
" incorrect or out of sync with the plugin. Please verify the table"
|
||||
" definition. Details: %s",
|
||||
"Schema mismatch: Field 'new_field' not found in table.",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client):
|
||||
await bq_plugin_inst.close()
|
||||
@@ -801,6 +834,42 @@ class TestBigQueryAgentAnalyticsPlugin:
|
||||
" Empty"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_before_model_callback_with_params_and_tools(
|
||||
self,
|
||||
bq_plugin_inst,
|
||||
mock_write_client,
|
||||
callback_context,
|
||||
dummy_arrow_schema,
|
||||
):
|
||||
llm_request = llm_request_lib.LlmRequest(
|
||||
model="gemini-pro",
|
||||
config=types.GenerateContentConfig(
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
system_instruction=types.Content(parts=[types.Part(text="Sys")]),
|
||||
),
|
||||
contents=[types.Content(role="user", parts=[types.Part(text="User")])],
|
||||
)
|
||||
# Manually set tools_dict as it is excluded from init
|
||||
llm_request.tools_dict = {"tool1": "func1", "tool2": "func2"}
|
||||
|
||||
await bq_plugin_inst.before_model_callback(
|
||||
callback_context=callback_context, llm_request=llm_request
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
|
||||
_assert_common_fields(log_entry, "LLM_REQUEST")
|
||||
# Order: Model | Params | Tools | Prompt | System Prompt
|
||||
# Note: Params order depends on dict iteration but here we construct it deterministically in code?
|
||||
# The code does: params_to_log["temperature"] = ... then "top_p" = ...
|
||||
# So order should be temperature, top_p.
|
||||
assert "Model: gemini-pro" in log_entry["content"]
|
||||
assert "Params: {temperature=0.5, top_p=0.9}" in log_entry["content"]
|
||||
assert "Available Tools: ['tool1', 'tool2']" in log_entry["content"]
|
||||
assert "Prompt: user: text: 'User'" in log_entry["content"]
|
||||
assert "System Prompt: Sys" in log_entry["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_model_callback_text_response(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user