From 5ac5129fb01913516d6f5348a825ca83d024d33a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 17 Nov 2025 15:00:09 -0800 Subject: [PATCH] feat: Enhance BQ Plugin Schema, Error Handling, and Logging This update enhances the BigQuery agent analytics plugin: * **Enhanced Error Logging:** Improved error messages for schema mismatches. * **Reordered Logging Content:** Prioritized metadata in `before_model_callback`. PiperOrigin-RevId: 833508755 --- .../bigquery_agent_analytics_plugin.py | 191 ++++++++++++++---- .../test_bigquery_agent_analytics_plugin.py | 71 ++++++- 2 files changed, 221 insertions(+), 41 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index ad189bbd..63b95e57 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -437,7 +437,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): return self._config.content_formatter(content), False return _format_content(content, max_len=self._config.max_content_length) except Exception as e: - logging.warning(f"Content formatter failed: {e}") + logging.warning("Content formatter failed: %s", e) return "[FORMATTING FAILED]", False async def _ensure_init(self): @@ -523,7 +523,21 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self._write_client.append_rows(iter([req])) ): if resp.error.code != 0: - logging.error(f"BQ Plugin: Write Error: {resp.error.message}") + msg = resp.error.message + # Check for common schema mismatch indicators + if ( + "schema mismatch" in msg.lower() + or "field" in msg.lower() + or "type" in msg.lower() + ): + logging.error( + "BQ Plugin: Schema Mismatch Error. The BigQuery table schema" + " may be incorrect or out of sync with the plugin. Please" + " verify the table definition. Details: %s", + msg, + ) + else: + logging.error("BQ Plugin: Write Error: %s", msg) except RuntimeError as e: if "Event loop is closed" not in str(e) and not self._is_shutting_down: @@ -578,7 +592,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): if self._background_tasks: logging.info( - f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..." + "BQ Plugin: Flushing %s pending logs...", len(self._background_tasks) ) try: await asyncio.wait( @@ -598,12 +612,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): timeout=self._config.client_close_timeout, ) except Exception as e: - logging.warning(f"BQ Plugin: Error closing write client: {e}") + logging.warning("BQ Plugin: Error closing write client: %s", e) if self._bq_client: try: self._bq_client.close() except Exception as e: - logging.warning(f"BQ Plugin: Error closing BQ client: {e}") + logging.warning("BQ Plugin: Error closing BQ client: %s", e) self._write_client = None self._bq_client = None @@ -617,7 +631,14 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): invocation_context: InvocationContext, user_message: types.Content, ) -> None: - """Callback for user messages.""" + """Callback for user messages. + + Logs the user message details including: + 1. User content (text) + + The content is formatted as 'User Content: {content}'. + If the content length exceeds `max_content_length`, it is truncated. + """ content, truncated = self._format_content_safely(user_message) await self._log({ "event_type": "USER_MESSAGE_RECEIVED", @@ -632,7 +653,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def before_run_callback( self, *, invocation_context: InvocationContext ) -> None: - """Callback before agent invocation.""" + """Callback before agent invocation. + + Logs the start of an agent invocation. + No specific content payload is logged for this event, but standard metadata + (agent name, session ID, invocation ID, user ID) is captured. + """ await self._log({ "event_type": "INVOCATION_STARTING", "agent": invocation_context.agent.name, @@ -644,7 +670,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event ) -> None: - """Callback for agent events.""" + """Callback for agent events. + + Logs generic agent events including: + 1. Event type (determined from event properties) + 2. Event content (text, function calls, or responses) + 3. Error messages (if any) + + The content is formatted based on the event type. + If the content length exceeds `max_content_length`, it is truncated. + """ content, truncated = self._format_content_safely(event.content) await self._log({ "event_type": _get_event_type(event), @@ -661,7 +696,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def after_run_callback( self, *, invocation_context: InvocationContext ) -> None: - """Callback after agent invocation.""" + """Callback after agent invocation. + + Logs the completion of an agent invocation. + No specific content payload is logged for this event, but standard metadata + (agent name, session ID, invocation ID, user ID) is captured. + """ await self._log({ "event_type": "INVOCATION_COMPLETED", "agent": invocation_context.agent.name, @@ -673,7 +713,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def before_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> None: - """Callback before an agent starts.""" + """Callback before an agent starts. + + Logs the start of a specific agent execution. + Content includes: + 1. Agent Name (from callback context) + """ await self._log({ "event_type": "AGENT_STARTING", "agent": agent.name, @@ -686,7 +731,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> None: - """Callback after an agent completes.""" + """Callback after an agent completes. + + Logs the completion of a specific agent execution. + Content includes: + 1. Agent Name (from callback context) + """ await self._log({ "event_type": "AGENT_COMPLETED", "agent": agent.name, @@ -699,11 +749,52 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> None: - """Callback before LLM call.""" + """Callback before LLM call. + + Logs the LLM request details including: + 1. Model name + 2. Configuration parameters (temperature, top_p, top_k, max_output_tokens) + 3. Available tool names + 4. Prompt content (user/model messages) + 5. System instructions + + The content is formatted as a single string with fields separated by ' | '. + If the total length exceeds `max_content_length`, the string is truncated, + prioritizing the metadata (Model, Params, Tools) over the Prompt and System + Prompt. + """ content_parts = [ f"Model: {llm_request.model or 'default'}", ] is_truncated = False + + # 1. Params + if llm_request.config: + config = llm_request.config + params_to_log = {} + if hasattr(config, "temperature") and config.temperature is not None: + params_to_log["temperature"] = config.temperature + if hasattr(config, "top_p") and config.top_p is not None: + params_to_log["top_p"] = config.top_p + if hasattr(config, "top_k") and config.top_k is not None: + params_to_log["top_k"] = config.top_k + if ( + hasattr(config, "max_output_tokens") + and config.max_output_tokens is not None + ): + params_to_log["max_output_tokens"] = config.max_output_tokens + + if params_to_log: + params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()]) + content_parts.append(f"Params: {{{params_str}}}") + + # 2. Tools + if llm_request.tools_dict: + content_parts.append( + f"Available Tools: {list(llm_request.tools_dict.keys())}" + ) + + # 3. Prompt if contents := getattr(llm_request, "contents", None): prompt_parts = [] for c in contents: @@ -713,6 +804,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): is_truncated = True prompt_str = " | ".join(prompt_parts) content_parts.append(f"Prompt: {prompt_str}") + + # 4. System Prompt system_instruction_text = "None" if llm_request.config and llm_request.config.system_instruction: si = llm_request.config.system_instruction @@ -736,29 +829,6 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): system_instruction_text = "Empty" content_parts.append(f"System Prompt: {system_instruction_text}") - if llm_request.config: - config = llm_request.config - params_to_log = {} - if hasattr(config, "temperature") and config.temperature is not None: - params_to_log["temperature"] = config.temperature - if hasattr(config, "top_p") and config.top_p is not None: - params_to_log["top_p"] = config.top_p - if hasattr(config, "top_k") and config.top_k is not None: - params_to_log["top_k"] = config.top_k - if ( - hasattr(config, "max_output_tokens") - and config.max_output_tokens is not None - ): - params_to_log["max_output_tokens"] = config.max_output_tokens - - if params_to_log: - params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()]) - content_parts.append(f"Params: {{{params_str}}}") - - if llm_request.tools_dict: - content_parts.append( - f"Available Tools: {list(llm_request.tools_dict.keys())}" - ) final_content = " | ".join(content_parts) max_len = self._config.max_content_length @@ -778,7 +848,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse ) -> None: - """Callback after LLM call.""" + """Callback after LLM call. + + Logs the LLM response details including: + 1. Tool calls (if any) + 2. Text response (if no tool calls) + 3. Token usage statistics (prompt, candidates, total) + + The content is formatted as a single string with fields separated by ' | '. + If the content length exceeds `max_content_length`, it is truncated. + """ content_parts = [] content = llm_response.content is_tool_call = False @@ -838,7 +917,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): tool_args: dict[str, Any], tool_context: ToolContext, ) -> None: - """Callback before tool call.""" + """Callback before tool call. + + Logs the tool execution start details including: + 1. Tool name + 2. Tool description + 3. Tool arguments + + The content is formatted as 'Tool Name: ..., Description: ..., Arguments: + ...'. + If the content length exceeds `max_content_length`, it is truncated. + """ args_str, truncated = _format_args( tool_args, max_len=self._config.max_content_length ) @@ -867,7 +956,15 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): tool_context: ToolContext, result: dict[str, Any], ) -> None: - """Callback after tool call.""" + """Callback after tool call. + + Logs the tool execution result details including: + 1. Tool name + 2. Tool result + + The content is formatted as 'Tool Name: ..., Result: ...'. + If the content length exceeds `max_content_length`, it is truncated. + """ result_str, truncated = _format_args( result, max_len=self._config.max_content_length ) @@ -892,7 +989,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): llm_request: LlmRequest, error: Exception, ) -> None: - """Callback for LLM errors.""" + """Callback for model errors. + + Logs errors that occur during LLM calls. + No specific content payload is logged, but the error message is captured + in the `error_message` field. + """ await self._log({ "event_type": "LLM_ERROR", "agent": callback_context.agent_name, @@ -910,7 +1012,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): tool_context: ToolContext, error: Exception, ) -> None: - """Callback for tool errors.""" + """Callback for tool errors. + + Logs errors that occur during tool execution. + Content includes: + 1. Tool name + 2. Tool arguments + + The error message is captured in the `error_message` field. + If the content length exceeds `max_content_length`, it is truncated. + """ args_str, truncated = _format_args( tool_args, max_len=self._config.max_content_length ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index abbd24b9..6f0412db 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -695,9 +695,42 @@ class TestBigQueryAgentAnalyticsPlugin: user_message=types.Content(parts=[types.Part(text="Test")]), ) await asyncio.sleep(0.01) - mock_log_error.assert_called_with("BQ Plugin: Write Error: Test BQ Error") + mock_log_error.assert_called_with( + "BQ Plugin: Write Error: %s", "Test BQ Error" + ) mock_write_client.append_rows.assert_called_once() + @pytest.mark.asyncio + async def test_schema_mismatch_error_handling( + self, bq_plugin_inst, mock_write_client, invocation_context + ): + async def fake_append_rows_with_schema_error(requests, **kwargs): + mock_resp = mock.MagicMock() + mock_resp.row_errors = [] + mock_resp.error = mock.MagicMock() + mock_resp.error.code = 3 + mock_resp.error.message = ( + "Schema mismatch: Field 'new_field' not found in table." + ) + return _async_gen(mock_resp) + + mock_write_client.append_rows.side_effect = ( + fake_append_rows_with_schema_error + ) + + with mock.patch.object(logging, "error") as mock_log_error: + await bq_plugin_inst.on_user_message_callback( + invocation_context=invocation_context, + user_message=types.Content(parts=[types.Part(text="Test")]), + ) + await asyncio.sleep(0.01) + mock_log_error.assert_called_with( + "BQ Plugin: Schema Mismatch Error. The BigQuery table schema may be" + " incorrect or out of sync with the plugin. Please verify the table" + " definition. Details: %s", + "Schema mismatch: Field 'new_field' not found in table.", + ) + @pytest.mark.asyncio async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client): await bq_plugin_inst.close() @@ -801,6 +834,42 @@ class TestBigQueryAgentAnalyticsPlugin: " Empty" ) + @pytest.mark.asyncio + async def test_before_model_callback_with_params_and_tools( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + temperature=0.5, + top_p=0.9, + system_instruction=types.Content(parts=[types.Part(text="Sys")]), + ), + contents=[types.Content(role="user", parts=[types.Part(text="User")])], + ) + # Manually set tools_dict as it is excluded from init + llm_request.tools_dict = {"tool1": "func1", "tool2": "func2"} + + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_REQUEST") + # Order: Model | Params | Tools | Prompt | System Prompt + # Note: Params order depends on dict iteration but here we construct it deterministically in code? + # The code does: params_to_log["temperature"] = ... then "top_p" = ... + # So order should be temperature, top_p. + assert "Model: gemini-pro" in log_entry["content"] + assert "Params: {temperature=0.5, top_p=0.9}" in log_entry["content"] + assert "Available Tools: ['tool1', 'tool2']" in log_entry["content"] + assert "Prompt: user: text: 'User'" in log_entry["content"] + assert "System Prompt: Sys" in log_entry["content"] + @pytest.mark.asyncio async def test_after_model_callback_text_response( self,