feat: Enhance TraceManager async safety, enrich BigQuery plugin logging, and fix serialization

*   **Async Safety:** Improved TraceManager context variable handling to ensure correct context isolation in concurrent asynchronous operations. This was achieved by using immutable tuples for the span stack and making copies of context dictionaries before modification.
*   **Enhanced Logging:** The BigQueryAgentAnalyticsPlugin now captures richer metadata, including:
    *   Root agent name (via a new context variable).
    *   LLM model name and version.
    *   Usage metadata from LLM requests and responses.
*   **Serialization Fix:** Updated BigQueryAgentAnalyticsPlugin to prevent JSON serialization errors when logging custom objects (e.g., Dataclasses). These are now automatically converted to dictionaries or string representations to ensure successful insertion into BigQuery.

PiperOrigin-RevId: 855415320
This commit is contained in:
Google Team Member
2026-01-12 15:45:42 -08:00
committed by Copybara-Service
parent 2592f01eb6
commit a4116a6cbf
2 changed files with 338 additions and 38 deletions
@@ -18,6 +18,7 @@ import asyncio
import atexit
from concurrent.futures import ThreadPoolExecutor
import contextvars
import dataclasses
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
@@ -120,6 +121,8 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
return obj, False
elif isinstance(obj, dict):
truncated_any = False
# Use dict comprehension for potentially slightly better performance,
# but explicit loop is fine for clarity given recursive nature.
new_dict = {}
for k, v in obj.items():
val, trunc = _recursive_smart_truncate(v, max_len)
@@ -130,13 +133,41 @@ def _recursive_smart_truncate(obj: Any, max_len: int) -> tuple[Any, bool]:
elif isinstance(obj, (list, tuple)):
truncated_any = False
new_list = []
# Explicit loop to handle flag propagation
for i in obj:
val, trunc = _recursive_smart_truncate(i, max_len)
if trunc:
truncated_any = True
new_list.append(val)
return type(obj)(new_list), truncated_any
return obj, False
elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
# Convert dataclasses to dicts so they become valid JSON objects
return _recursive_smart_truncate(dataclasses.asdict(obj), max_len)
elif hasattr(obj, "model_dump") and callable(obj.model_dump):
# Pydantic v2
try:
return _recursive_smart_truncate(obj.model_dump(), max_len)
except Exception:
pass
elif hasattr(obj, "dict") and callable(obj.dict):
# Pydantic v1
try:
return _recursive_smart_truncate(obj.dict(), max_len)
except Exception:
pass
elif hasattr(obj, "to_dict") and callable(obj.to_dict):
# Common pattern for custom objects
try:
return _recursive_smart_truncate(obj.to_dict(), max_len)
except Exception:
pass
elif obj is None or isinstance(obj, (int, float, bool)):
# Basic types are safe
return obj, False
# Fallback for unknown types: Convert to string to ensure JSON validity
# We return string representation of the object, which is a valid JSON string value.
return str(obj), False
# --- PyArrow Helper Functions ---
@@ -352,9 +383,10 @@ class BigQueryLoggerConfig:
# ==============================================================================
_trace_id_ctx = contextvars.ContextVar("_bq_analytics_trace_id", default=None)
_span_stack_ctx = contextvars.ContextVar(
"_bq_analytics_span_stack", default=None
_root_agent_name_ctx = contextvars.ContextVar(
"_bq_analytics_root_agent_name", default=None
)
_span_stack_ctx = contextvars.ContextVar("_bq_analytics_span_stack", default=())
_span_times_ctx = contextvars.ContextVar(
"_bq_analytics_span_times", default=None
)
@@ -370,7 +402,13 @@ class TraceManager:
def init_trace(callback_context: CallbackContext) -> None:
if _trace_id_ctx.get() is None:
_trace_id_ctx.set(callback_context.invocation_id)
_span_stack_ctx.set([])
# Extract root agent name from invocation context
try:
root_agent = callback_context._invocation_context.agent.root_agent
_root_agent_name_ctx.set(root_agent.name)
except (AttributeError, ValueError):
pass
_span_stack_ctx.set(())
_span_times_ctx.set({})
_span_first_token_times_ctx.set({})
@@ -393,39 +431,29 @@ class TraceManager:
span_id = span_id or str(uuid.uuid4())
stack = _span_stack_ctx.get()
if stack is None:
# Should have been called by init_trace, but just in case
stack = []
_span_stack_ctx.set(stack)
stack.append(span_id)
times = _span_times_ctx.get()
if times is None:
times = {}
_span_times_ctx.set(times)
first_tokens = _span_first_token_times_ctx.get()
if first_tokens is None:
first_tokens = {}
_span_first_token_times_ctx.set(first_tokens)
new_stack = stack + (span_id,)
_span_stack_ctx.set(new_stack)
times = dict(_span_times_ctx.get() or {})
times[span_id] = time.time()
_span_times_ctx.set(times)
return span_id
@staticmethod
def pop_span() -> tuple[Optional[str], Optional[int]]:
stack = _span_stack_ctx.get()
stack = list(_span_stack_ctx.get())
if not stack:
return None, None
span_id = stack.pop()
_span_stack_ctx.set(tuple(stack))
times = _span_times_ctx.get()
start_time = times.pop(span_id, None) if times else None
times_dict = dict(_span_times_ctx.get() or {})
start_time = times_dict.pop(span_id, None)
_span_times_ctx.set(times_dict)
first_tokens = _span_first_token_times_ctx.get()
if first_tokens:
first_tokens.pop(span_id, None)
ft_dict = dict(_span_first_token_times_ctx.get() or {})
ft_dict.pop(span_id, None)
_span_first_token_times_ctx.set(ft_dict)
duration_ms = int((time.time() - start_time) * 1000) if start_time else None
return span_id, duration_ms
@@ -442,6 +470,10 @@ class TraceManager:
stack = _span_stack_ctx.get()
return stack[-1] if stack else None
@staticmethod
def get_root_agent_name() -> Optional[str]:
return _root_agent_name_ctx.get()
@staticmethod
def get_start_time(span_id: str) -> Optional[float]:
times = _span_times_ctx.get()
@@ -454,13 +486,10 @@ class TraceManager:
Returns:
True if this was the first token (newly recorded), False otherwise.
"""
first_tokens = _span_first_token_times_ctx.get()
if first_tokens is None:
first_tokens = {}
_span_first_token_times_ctx.set(first_tokens)
first_tokens = dict(_span_first_token_times_ctx.get() or {})
if span_id not in first_tokens:
first_tokens[span_id] = time.time()
_span_first_token_times_ctx.set(first_tokens)
return True
return False
@@ -1218,7 +1247,10 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
mode="NULLABLE",
description=(
"A JSON object containing arbitrary key-value pairs for"
" additional event metadata not covered by standard fields."
" additional event metadata. Includes enrichment fields like"
" 'root_agent_name' (turn orchestration), 'model' (request"
" model), 'model_version' (response version), and"
" 'usage_metadata' (detailed token counts)."
),
),
bigquery.SchemaField(
@@ -1420,7 +1452,6 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
# Use weakref to avoid circular references that prevent garbage collection
atexit.register(self._atexit_cleanup, weakref.proxy(self.batch_processor))
@staticmethod
@staticmethod
def _atexit_cleanup(batch_processor: "BatchProcessor") -> None:
"""Clean up batch processor on script exit."""
@@ -1563,7 +1594,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
async def _ensure_started(self, **kwargs) -> None:
"""Ensures that the plugin is started and initialized."""
if not self._started:
# Kept original lock name as it was not explicitly changed in the
# Kept original lock name as it was not explicitly changed.
if self._setup_lock is None:
self._setup_lock = asyncio.Lock()
async with self._setup_lock:
@@ -1660,6 +1691,28 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
status = kwargs.pop("status", "OK")
error_message = kwargs.pop("error_message", None)
# V2 Metadata Extensions
model = kwargs.pop("model", None)
model_version = kwargs.pop("model_version", None)
usage_metadata = kwargs.pop("usage_metadata", None)
# Add new fields to attributes instead of columns
kwargs["root_agent_name"] = TraceManager.get_root_agent_name()
if model:
kwargs["model"] = model
if model_version:
kwargs["model_version"] = model_version
if usage_metadata:
# Use smart truncate to handle Pydantic, Dataclasses, and other objects
usage_dict, _ = _recursive_smart_truncate(
usage_metadata, self.config.max_content_length
)
if isinstance(usage_dict, dict):
kwargs["usage_metadata"] = usage_dict
else:
# Fallback if it couldn't be converted to dict
kwargs["usage_metadata"] = usage_metadata
# Serialize remaining kwargs to JSON string for attributes
try:
attributes_json = json.dumps(kwargs)
@@ -1822,6 +1875,7 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
"LLM_REQUEST",
callback_context,
raw_content=llm_request,
model=llm_request.model,
**attributes,
)
@@ -1921,6 +1975,8 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
raw_content=content_str,
is_truncated=is_truncated,
latency_ms=duration,
model_version=llm_response.model_version,
usage_metadata=llm_response.usage_metadata,
span_id_override=span_id if is_popped else None,
parent_span_id_override=parent_span_id
if is_popped
@@ -15,6 +15,7 @@
from __future__ import annotations
import asyncio
import dataclasses
import json
from unittest import mock
@@ -150,6 +151,7 @@ def mock_write_client():
def dummy_arrow_schema():
return pa.schema([
pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False),
pa.field("root_agent_name", pa.string(), nullable=True),
pa.field("event_type", pa.string(), nullable=True),
pa.field("agent", pa.string(), nullable=True),
pa.field("session_id", pa.string(), nullable=True),
@@ -288,6 +290,36 @@ async def _get_captured_event_dict_async(mock_write_client, expected_schema):
return {k: v[0] for k, v in pydict.items()}
async def _get_captured_rows_async(mock_write_client, expected_schema):
"""Helper to get all rows passed to append_rows."""
all_rows = []
for call in mock_write_client.append_rows.call_args_list:
requests_iter = call.args[0]
requests = []
if hasattr(requests_iter, "__aiter__"):
async for req in requests_iter:
requests.append(req)
else:
requests = list(requests_iter)
for request in requests:
# Parse the Arrow batch back to a dict for verification
try:
reader = pa.ipc.open_stream(
request.arrow_rows.rows.serialized_record_batch
)
table = reader.read_all()
except Exception:
# Fallback: try reading as a single batch
buf = pa.py_buffer(request.arrow_rows.rows.serialized_record_batch)
batch = pa.ipc.read_record_batch(buf, expected_schema)
table = pa.Table.from_batches([batch])
pydict = table.to_pylist()
all_rows.extend(pydict)
return all_rows
def _assert_common_fields(log_entry, event_type, agent="MyTestAgent"):
assert log_entry["event_type"] == event_type
assert log_entry["agent"] == agent
@@ -315,6 +347,40 @@ def test_recursive_smart_truncate():
assert truncated["c"]["d"] == "long strin...[TRUNCATED]"
def test_recursive_smart_truncate_with_dataclasses():
"""Test recursive smart truncate with dataclasses."""
@dataclasses.dataclass
class LocalMissedKPI:
kpi: str
value: float
@dataclasses.dataclass
class LocalIncident:
id: str
kpi_missed: list[LocalMissedKPI]
status: str
incident = LocalIncident(
id="inc-123",
kpi_missed=[LocalMissedKPI(kpi="latency", value=99.9)],
status="active",
)
content = {"result": incident}
max_len = 1000
truncated, is_truncated = (
bigquery_agent_analytics_plugin._recursive_smart_truncate(
content, max_len
)
)
assert not is_truncated
assert isinstance(truncated["result"], dict)
assert truncated["result"]["id"] == "inc-123"
assert isinstance(truncated["result"]["kpi_missed"][0], dict)
assert truncated["result"]["kpi_missed"][0]["kpi"] == "latency"
# --- Test Class ---
@@ -344,7 +410,121 @@ class TestBigQueryAgentAnalyticsPlugin:
)
mock_auth_default.assert_not_called()
mock_bq_client.assert_not_called()
mock_write_client.append_rows.assert_not_called()
@pytest.mark.asyncio
async def test_enriched_metadata_logging(
self,
mock_auth_default,
mock_bq_client,
mock_write_client,
mock_to_arrow_schema,
dummy_arrow_schema,
callback_context,
):
# Setup
config = BigQueryLoggerConfig()
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, config=config
)
# Mock root agent
mock_root = mock.create_autospec(
base_agent.BaseAgent, instance=True, spec_set=True
)
type(mock_root).name = mock.PropertyMock(return_value="RootAgent")
callback_context._invocation_context.agent.root_agent = mock_root
# 1. Test root_agent_name and model extraction from request
llm_request = llm_request_lib.LlmRequest(
model="gemini-pro",
contents=[types.Content(parts=[types.Part(text="Hi")])],
)
await plugin.before_model_callback(
callback_context=callback_context, llm_request=llm_request
)
# 2. Test model_version and usage_metadata extraction from response
usage = types.GenerateContentResponseUsageMetadata(
prompt_token_count=10, candidates_token_count=20, total_token_count=30
)
llm_response = llm_response_lib.LlmResponse(
content=types.Content(parts=[types.Part(text="Hello")]),
usage_metadata=usage,
model_version="v1.2.3",
)
await plugin.after_model_callback(
callback_context=callback_context, llm_response=llm_response
)
await plugin.shutdown()
# Verify captured rows from mock client
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
assert len(rows) == 2
# Check LLM_REQUEST row
# Sort by event_type to ensure consistent indexing
rows.sort(key=lambda x: x["event_type"])
request_row = rows[0] # LLM_REQUEST
response_row = rows[1] # LLM_RESPONSE
assert request_row["event_type"] == "LLM_REQUEST"
attr_req = json.loads(request_row["attributes"])
assert attr_req["root_agent_name"] == "RootAgent"
assert attr_req["model"] == "gemini-pro"
# Check LLM_RESPONSE row
assert response_row["event_type"] == "LLM_RESPONSE"
attr_res = json.loads(response_row["attributes"])
assert attr_res["root_agent_name"] == "RootAgent"
assert attr_res["model_version"] == "v1.2.3"
usage_meta = attr_res["usage_metadata"]
assert "prompt_token_count" in usage_meta
assert usage_meta["prompt_token_count"] == 10
mock_write_client.append_rows.assert_called()
@pytest.mark.asyncio
async def test_concurrent_span_management(
self,
mock_auth_default,
mock_bq_client,
mock_write_client,
mock_to_arrow_schema,
callback_context,
):
# Setup
config = BigQueryLoggerConfig()
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, config=config
)
# Initialize trace in main context
bigquery_agent_analytics_plugin.TraceManager.init_trace(callback_context)
async def branch_1():
bigquery_agent_analytics_plugin.TraceManager.push_span(
callback_context, span_id="span-1"
)
await asyncio.sleep(0.02)
s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id()
bigquery_agent_analytics_plugin.TraceManager.pop_span()
return s_id
async def branch_2():
bigquery_agent_analytics_plugin.TraceManager.push_span(
callback_context, span_id="span-2"
)
await asyncio.sleep(0.02)
s_id = bigquery_agent_analytics_plugin.TraceManager.get_current_span_id()
bigquery_agent_analytics_plugin.TraceManager.pop_span()
return s_id
# Run concurrently
results = await asyncio.gather(branch_1(), branch_2())
# If they shared the same list/dict, they would interfere.
assert "span-1" in results
assert "span-2" in results
assert results[0] != results[1]
@pytest.mark.asyncio
async def test_event_allowlist(
@@ -1704,8 +1884,72 @@ class TestBigQueryAgentAnalyticsPlugin:
assert log_entry_resp["parent_span_id"] == agent_span_id
assert log_entry_resp["parent_span_id"] != log_entry_resp["span_id"]
# Verify Span was popped
current_span = (
# Verify LLM Span was popped and we are back to Agent Span
assert (
bigquery_agent_analytics_plugin.TraceManager.get_current_span_id()
== agent_span_id
)
assert current_span == agent_span_id
# Clean up Agent Span
bigquery_agent_analytics_plugin.TraceManager.pop_span()
assert (
not bigquery_agent_analytics_plugin.TraceManager.get_current_span_id()
)
@pytest.mark.asyncio
async def test_custom_object_serialization(
self,
mock_write_client,
tool_context,
mock_auth_default,
mock_bq_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
"""Verifies that custom objects (Dataclasses) are serialized to dicts."""
_ = mock_auth_default
_ = mock_bq_client
@dataclasses.dataclass
class LocalMissedKPI:
kpi: str
value: float
@dataclasses.dataclass
class LocalIncident:
id: str
kpi_missed: list[LocalMissedKPI]
status: str
incident = LocalIncident(
id="inc-123",
kpi_missed=[LocalMissedKPI(kpi="latency", value=99.9)],
status="active",
)
config = BigQueryLoggerConfig()
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
PROJECT_ID, DATASET_ID, table_id=TABLE_ID, config=config
)
await plugin._ensure_started()
mock_write_client.append_rows.reset_mock()
content = {"result": incident}
# Verify full flow
await plugin._log_event(
"TOOL_PARTIAL",
tool_context,
raw_content=content,
)
await asyncio.sleep(0.01)
mock_write_client.append_rows.assert_called_once()
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)
# Content should be valid JSON string
content_json = json.loads(log_entry["content"])
assert content_json["result"]["id"] == "inc-123"
assert content_json["result"]["kpi_missed"][0]["kpi"] == "latency"