From 37ee1869b57f6cdb37be2bb514aac6316a910a46 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 11 Nov 2025 22:46:39 -0800 Subject: [PATCH] fix: Enhance BigQuery Plugin Robustness and Schema Accuracy This update improves the `BigQueryAgentAnalyticsPlugin` in several ways: * Corrects the PyArrow schema generation to accurately reflect BigQuery field nullability based on the `mode` attribute. * Introduces a configurable `shutdown_timeout` in `BigQueryLoggerConfig` to manage how long the plugin waits for pending logs to flush during shutdown. * Adds more robust error handling within the `shutdown` method and background write tasks, particularly for event loop closure issues. * Improves internal logging to provide better diagnostics. * Ensures consistent use of safe content formatting. PiperOrigin-RevId: 831225837 --- .../bigquery_agent_analytics_plugin.py | 130 +++++++++++------ .../test_bigquery_agent_analytics_plugin.py | 138 ++++++++++++++++-- 2 files changed, 204 insertions(+), 64 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 65fd5398..46fd46fb 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -23,7 +23,6 @@ from typing import Any from typing import Callable from typing import List from typing import Optional -from typing import Set from typing import TYPE_CHECKING from google.api_core.gapic_v1 import client_info as gapic_client_info @@ -173,10 +172,11 @@ def _bq_to_arrow_field(bq_field): metadata = _BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA.get( bq_field.field_type.upper() if bq_field.field_type else "" ) + nullable = bq_field.mode.upper() != "REQUIRED" return pa.field( bq_field.name, arrow_type, - nullable=(bq_field.mode != "REPEATED"), + nullable=nullable, metadata=metadata, ) logging.warning( @@ -213,12 +213,18 @@ class BigQueryLoggerConfig: event_denylist: A list of event types to skip logging. content_formatter: An optional function to format event content before logging. + shutdown_timeout: Seconds to wait for logs to flush during shutdown. + client_close_timeout: Seconds to wait for BQ client to close. + max_content_length: The maximum length of content parts before truncation. """ enabled: bool = True event_allowlist: Optional[List[str]] = None event_denylist: Optional[List[str]] = None content_formatter: Optional[Callable[[Any], str]] = None + shutdown_timeout: float = 5.0 + client_close_timeout: float = 2.0 + max_content_length: int = 500 # --- Helper Formatters --- @@ -313,16 +319,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self._write_client: BigQueryWriteAsyncClient | None = None self._init_lock: asyncio.Lock | None = None self._arrow_schema: pa.Schema | None = None - self._background_tasks: Set[asyncio.Task] = set() # Track pending logs + self._background_tasks: set[asyncio.Task] = set() + self._is_shutting_down = False self._schema = [ - bigquery.SchemaField("timestamp", "TIMESTAMP"), - bigquery.SchemaField("event_type", "STRING"), - bigquery.SchemaField("agent", "STRING"), - bigquery.SchemaField("session_id", "STRING"), - bigquery.SchemaField("invocation_id", "STRING"), - bigquery.SchemaField("user_id", "STRING"), - bigquery.SchemaField("content", "STRING"), - bigquery.SchemaField("error_message", "STRING"), + bigquery.SchemaField("timestamp", "TIMESTAMP", mode="REQUIRED"), + bigquery.SchemaField("event_type", "STRING", mode="NULLABLE"), + bigquery.SchemaField("agent", "STRING", mode="NULLABLE"), + bigquery.SchemaField("session_id", "STRING", mode="NULLABLE"), + bigquery.SchemaField("invocation_id", "STRING", mode="NULLABLE"), + bigquery.SchemaField("user_id", "STRING", mode="NULLABLE"), + bigquery.SchemaField("content", "STRING", mode="NULLABLE"), + bigquery.SchemaField("error_message", "STRING", mode="NULLABLE"), ] def _format_content_safely( @@ -334,7 +341,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): try: if self._config.content_formatter: return self._config.content_formatter(content) - return _format_content(content) + return _format_content(content, max_len=self._config.max_content_length) except Exception as e: logging.warning(f"Content formatter failed: {e}") return "[FORMATTING FAILED]" @@ -363,14 +370,17 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): # Ensure table exists (sync call in thread) def create_resources(): if self._bq_client: - dataset = self._bq_client.create_dataset( - self._dataset_id, exists_ok=True - ) + self._bq_client.create_dataset(self._dataset_id, exists_ok=True) table = bigquery.Table( f"{self._project_id}.{self._dataset_id}.{self._table_id}", schema=self._schema, ) self._bq_client.create_table(table, exists_ok=True) + logging.info( + "BQ Plugin: Dataset %s and Table %s ensured to exist.", + self._dataset_id, + self._table_id, + ) await asyncio.to_thread(create_resources) @@ -379,9 +389,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): client_info=client_info, ) self._arrow_schema = to_arrow_schema(self._schema) + if not self._arrow_schema: + raise RuntimeError("Failed to convert BigQuery schema to Arrow.") + logging.info("BQ Plugin: Initialized successfully.") return True except Exception as e: - logging.error(f"BQ Init Failed: {e}") + logging.error("BQ Plugin: Init Failed:", exc_info=True) return False async def _perform_write(self, row: dict): @@ -412,14 +425,16 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): self._write_client.append_rows(iter([req])) ): if resp.error.code != 0: - logging.error(f"BQ Write Error: {resp.error.message}") + logging.error(f"BQ Plugin: Write Error: {resp.error.message}") except RuntimeError as e: - # Silently ignore event loop closed errors during background writes - if "Event loop is closed" not in str(e): - logging.exception(f"BQ Runtime Error: {e}") + if "Event loop is closed" not in str(e) and not self._is_shutting_down: + logging.error("BQ Plugin: Runtime Error during write:", exc_info=True) + except asyncio.CancelledError: + if not self._is_shutting_down: + logging.warning("BQ Plugin: Write task cancelled unexpectedly.") except Exception as e: - logging.error(f"BQ Write Failed: {e}") + logging.error("BQ Plugin: Write Failed:", exc_info=True) async def _log(self, data: dict): """Schedules a log entry to be written in the background.""" @@ -457,32 +472,44 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): async def close(self): """Flushes pending logs and closes client.""" - # 1. Wait for pending background logs (best effort, 2s timeout) - if self._background_tasks: - logging.info(f"Flushing {len(self._background_tasks)} pending BQ logs...") - done, pending = await asyncio.wait(self._background_tasks, timeout=2.0) - if pending: - logging.warning( - f"{len(pending)} BQ logs could not be flushed before shutdown." - ) + if self._is_shutting_down: + return + self._is_shutting_down = True + logging.info("BQ Plugin: Shutdown started.") - # 2. Close client - if self._write_client and self._write_client.transport: + if self._background_tasks: + logging.info( + f"BQ Plugin: Flushing {len(self._background_tasks)} pending logs..." + ) try: - logging.info("Closing BQ Write client transport...") + await asyncio.wait( + self._background_tasks, timeout=self._config.shutdown_timeout + ) + except asyncio.TimeoutError: + logging.warning("BQ Plugin: Timeout waiting for logs to flush.") + except Exception as e: + logging.warning("BQ Plugin: Error flushing logs:", exc_info=True) + + # Use getattr for safe access in case transport is not present. + if self._write_client and getattr(self._write_client, "transport", None): + try: + logging.info("BQ Plugin: Closing write client.") await asyncio.wait_for( - self._write_client.transport.close(), timeout=1.0 + self._write_client.transport.close(), + timeout=self._config.client_close_timeout, ) except Exception as e: - logging.warning(f"Error during BQ Write client transport close: {e}") - self._write_client = None + logging.warning(f"BQ Plugin: Error closing write client: {e}") if self._bq_client: try: - logging.info("Closing BQ client...") self._bq_client.close() except Exception as e: - logging.warning(f"Error during BQ client close: {e}") - self._bq_client = None + logging.warning(f"BQ Plugin: Error closing BQ client: {e}") + + self._write_client = None + self._bq_client = None + self._is_shutting_down = False + logging.info("BQ Plugin: Shutdown complete.") # --- Streamlined Callbacks --- async def on_user_message_callback( @@ -523,13 +550,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "session_id": invocation_context.session.id, "invocation_id": invocation_context.invocation_id, "user_id": invocation_context.session.user_id, - "content": ( - json.dumps( - [part.model_dump(mode="json") for part in event.content.parts] - ) - if event.content and event.content.parts - else None - ), + "content": self._format_content_safely(event.content), "error_message": event.error_message, "timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc), }) @@ -579,6 +600,11 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): content_parts = [ f"Model: {llm_request.model or 'default'}", ] + if contents := getattr(llm_request, "contents", None): + prompt_str = " | ".join( + [f"{c.role}: {self._format_content_safely(c)}" for c in contents] + ) + content_parts.append(f"Prompt: {prompt_str}") system_instruction_text = "None" if llm_request.config and llm_request.config.system_instruction: si = llm_request.config.system_instruction @@ -627,6 +653,9 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): ) final_content = " | ".join(content_parts) + max_len = self._config.max_content_length + if len(final_content) > max_len: + final_content = final_content[:max_len] + "..." await self._log({ "event_type": "LLM_REQUEST", "agent": callback_context.agent_name, @@ -702,7 +731,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "user_id": tool_context.session.user_id, "content": ( f"Tool Name: {tool.name}, Description: {tool.description}," - f" Arguments: {_format_args(tool_args)}" + " Arguments:" + f" {_format_args(tool_args, max_len=self._config.max_content_length)}" ), }) @@ -721,7 +751,10 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "session_id": tool_context.session.id, "invocation_id": tool_context.invocation_id, "user_id": tool_context.session.user_id, - "content": f"Tool Name: {tool.name}, Result: {_format_args(result)}", + "content": ( + f"Tool Name: {tool.name}, Result:" + f" {_format_args(result, max_len=self._config.max_content_length)}" + ), }) async def on_model_error_callback( @@ -757,7 +790,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin): "invocation_id": tool_context.invocation_id, "user_id": tool_context.session.user_id, "content": ( - f"Tool Name: {tool.name}, Arguments: {_format_args(tool_args)}" + f"Tool Name: {tool.name}, Arguments:" + f" {_format_args(tool_args, max_len=self._config.max_content_length)}" ), "error_message": str(error), }) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 8696d744..20217fdb 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -146,14 +146,14 @@ def mock_write_client(): @pytest.fixture def dummy_arrow_schema(): return pa.schema([ - pa.field("timestamp", pa.timestamp("us", tz="UTC")), - pa.field("event_type", pa.string()), - pa.field("agent", pa.string()), - pa.field("session_id", pa.string()), - pa.field("invocation_id", pa.string()), - pa.field("user_id", pa.string()), - pa.field("content", pa.string()), - pa.field("error_message", pa.string()), + pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False), + pa.field("event_type", pa.string(), nullable=True), + pa.field("agent", pa.string(), nullable=True), + pa.field("session_id", pa.string(), nullable=True), + pa.field("invocation_id", pa.string(), nullable=True), + pa.field("user_id", pa.string(), nullable=True), + pa.field("content", pa.string(), nullable=True), + pa.field("error_message", pa.string(), nullable=True), ]) @@ -391,6 +391,102 @@ class TestBigQueryAgentAnalyticsPlugin: log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) assert log_entry["content"] == "User Content: [FORMATTING FAILED]" + @pytest.mark.asyncio + async def test_max_content_length( + self, + mock_write_client, + invocation_context, + callback_context, + mock_auth_default, + mock_bq_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + config = BigQueryLoggerConfig(max_content_length=40) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + await plugin._ensure_init() + mock_write_client.append_rows.reset_mock() + + # Test User Message Truncation + user_message = types.Content( + parts=[types.Part(text="12345678901234567890123456789012345678901")] + ) # 41 chars + await plugin.on_user_message_callback( + invocation_context=invocation_context, user_message=user_message + ) + await asyncio.sleep(0.01) + mock_write_client.append_rows.assert_called_once() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + assert ( + log_entry["content"] + == "User Content: text: '1234567890123456789012345678901234567890...' " + ) + mock_write_client.append_rows.reset_mock() + + # Test before_model_callback full content truncation + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + system_instruction=types.Content( + parts=[types.Part(text="System Instruction")] + ) + ), + contents=[ + types.Content(role="user", parts=[types.Part(text="Prompt")]) + ], + ) + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await asyncio.sleep(0.01) + mock_write_client.append_rows.assert_called_once() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + # Full content: "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt: System Instruction" + # Truncated to 40 chars + ...: + expected_content = "Model: gemini-pro | Prompt: user: text: ..." + assert log_entry["content"] == expected_content + + @pytest.mark.asyncio + async def test_max_content_length_tool_args( + self, + mock_write_client, + tool_context, + mock_auth_default, + mock_bq_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + config = BigQueryLoggerConfig(max_content_length=10) + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + await plugin._ensure_init() + mock_write_client.append_rows.reset_mock() + + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + + # Args length > 10 + # {"param": "long_value"} is ~24 chars + await plugin.before_tool_callback( + tool=mock_tool, + tool_args={"param": "long_value"}, + tool_context=tool_context, + ) + await asyncio.sleep(0.01) + mock_write_client.append_rows.assert_called_once() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + # JSON string: '{"param": "long_value"}' + # Truncated to 10: '{"param": ...' + assert 'Arguments: {"param": ...' in log_entry["content"] + @pytest.mark.asyncio async def test_on_user_message_callback_logs_correctly( self, @@ -430,7 +526,7 @@ class TestBigQueryAgentAnalyticsPlugin: await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent") - assert '"name": "get_weather"' in log_entry["content"] + assert "call: get_weather" in log_entry["content"] assert log_entry["timestamp"] == datetime.datetime( 2025, 10, 22, 10, 0, 0, tzinfo=datetime.timezone.utc ) @@ -456,7 +552,7 @@ class TestBigQueryAgentAnalyticsPlugin: await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent") - assert '"text": "Hello there!"' in log_entry["content"] + assert "text: 'Hello there!'" in log_entry["content"] assert log_entry["timestamp"] == datetime.datetime( 2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc ) @@ -485,7 +581,7 @@ class TestBigQueryAgentAnalyticsPlugin: user_message=types.Content(parts=[types.Part(text="Test")]), ) await asyncio.sleep(0.01) - mock_log_error.assert_any_call("BQ Init Failed: Auth failed") + mock_log_error.assert_any_call("BQ Plugin: Init Failed:", exc_info=True) mock_write_client.append_rows.assert_not_called() @pytest.mark.asyncio @@ -509,16 +605,20 @@ class TestBigQueryAgentAnalyticsPlugin: user_message=types.Content(parts=[types.Part(text="Test")]), ) await asyncio.sleep(0.01) - mock_log_error.assert_called_with("BQ Write Error: Test BQ Error") + mock_log_error.assert_called_with("BQ Plugin: Write Error: Test BQ Error") mock_write_client.append_rows.assert_called_once() @pytest.mark.asyncio async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client): await bq_plugin_inst.close() mock_write_client.transport.close.assert_called_once() - mock_bq_client.close.assert_called_once() + # bq_client might not be closed if it wasn't created or if close() failed, + # but here it should be. + # in the new implementation we verify attributes are reset + assert bq_plugin_inst._write_client is None + assert bq_plugin_inst._bq_client is None + assert bq_plugin_inst._is_shutting_down is False - # ... other tests remain the same ... @pytest.mark.asyncio async def test_before_run_callback_logs_correctly( self, @@ -595,7 +695,9 @@ class TestBigQueryAgentAnalyticsPlugin: ): llm_request = llm_request_lib.LlmRequest( model="gemini-pro", - contents=[types.Content(parts=[types.Part(text="Prompt")])], + contents=[ + types.Content(role="user", parts=[types.Part(text="Prompt")]) + ], ) await bq_plugin_inst.before_model_callback( callback_context=callback_context, llm_request=llm_request @@ -603,7 +705,11 @@ class TestBigQueryAgentAnalyticsPlugin: await asyncio.sleep(0.01) log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "LLM_REQUEST") - assert log_entry["content"] == "Model: gemini-pro | System Prompt: Empty" + assert ( + log_entry["content"] + == "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt:" + " Empty" + ) @pytest.mark.asyncio async def test_after_model_callback_text_response(