feat: Add schema auto-upgrade, tool provenance, HITL tracing, and span hierarchy fix to BigQuery Agent Analytics plugin

This CL adds three enhancements to the BigQuery Agent Analytics plugin (schema auto-upgrade, tool provenance, and HITL tracing) and fixes a span hierarchy corruption bug.

- **Schema Auto-Upgrade:** Additive-only schema migration that automatically adds missing columns to existing BQ tables on startup. An `adk_schema_version` label on the table (starting at `"1"`, bumped with each schema change) makes the check idempotent — the diff runs at most once per version. Enabled by default (`auto_schema_upgrade=True`) because upgrades are additive-only and fail-safe. Pre-versioning tables (no label) are treated as outdated, diffed, and stamped. No previous schema versions need to be stored; the logic diffs actual columns against the current canonical schema.

- **Tool Provenance:** Adds `tool_origin` to TOOL_* event content, distinguishing six origin types — `LOCAL` (FunctionTool), `MCP` (McpTool), `A2A` (AgentTool wrapping RemoteA2aAgent), `SUB_AGENT` (AgentTool), `TRANSFER_AGENT` (TransferToAgentTool), and `UNKNOWN` (fallback) — via `isinstance()` checks with lazy imports to avoid circular dependencies.

- **HITL Tracing:** Emits dedicated HITL event types (`HITL_CONFIRMATION_REQUEST`, `HITL_CREDENTIAL_REQUEST`, `HITL_INPUT_REQUEST` + `_COMPLETED` variants) for human-in-the-loop interactions. Detection lives in `on_event_callback` (for synthetic `adk_request_*` FunctionCall events emitted by the framework) and `on_user_message_callback` (for `adk_request_*` FunctionResponse completions sent by the user), not in tool callbacks — because `adk_request_*` names are synthetic function calls that bypass `before_tool_callback`/`after_tool_callback` entirely.

- **Span Hierarchy Fix (#4561):** Removes `context.attach()`/`context.detach()` calls from `TraceManager.push_span()`, `attach_current_span()`, and `pop_span()`. The plugin was injecting its spans into the shared OTel context, which corrupted the framework's span hierarchy when an external exporter (e.g. `opentelemetry-instrumentation-vertexai`) was active — causing `call_llm` to be re-parented under `llm_request` and parent spans to show shorter durations than children. The plugin now tracks span_id/parent_span_id via its internal contextvar stack without mutating ambient OTel context.

Co-authored-by: Haiyuan Cao <haiyuan@google.com>
PiperOrigin-RevId: 873114688
This commit is contained in:
Haiyuan Cao
2026-02-20 16:11:31 -08:00
committed by Copybara-Service
parent e8019b1b1b
commit 4260ef0c7c
2 changed files with 970 additions and 31 deletions
@@ -51,7 +51,6 @@ from google.cloud.bigquery import schema as bq_schema
from google.cloud.bigquery_storage_v1 import types as bq_storage_types
from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
from google.genai import types
from opentelemetry import context
from opentelemetry import trace
import pyarrow as pa
@@ -71,6 +70,24 @@ tracer = trace.get_tracer(
"google.adk.plugins.bigquery_agent_analytics", __version__
)
# Label key under which the schema version is stored on the table.
_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"
# Current schema version; bump it (1 → 2 → 3 …) whenever the schema
# changes. It is written as a table label for governance and compared
# against the stored label to decide whether auto-upgrade should run.
_SCHEMA_VERSION = "1"

# Read-only mapping from each human-in-the-loop (HITL) function name to
# the dedicated event type emitted for it.
_HITL_EVENT_MAP = MappingProxyType({
    "adk_request_credential": "HITL_CREDENTIAL_REQUEST",
    "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST",
    "adk_request_input": "HITL_INPUT_REQUEST",
})
# The HITL function names themselves, derived from the mapping keys so the
# two constants cannot drift apart.
_HITL_TOOL_NAMES = frozenset(_HITL_EVENT_MAP)
def _safe_callback(func):
"""Decorator that catches and logs exceptions in plugin callbacks.
@@ -132,6 +149,47 @@ def _format_content(
return " | ".join(parts), truncated
def _get_tool_origin(tool: "BaseTool") -> str:
  """Classifies a tool by where its implementation comes from.

  The tool classes are imported inside the function (not at module top)
  to avoid circular dependencies. ``McpTool`` and ``RemoteA2aAgent`` are
  treated as optional: if importing either raises ``ImportError`` the
  corresponding category is simply never returned.

  Args:
    tool: The tool instance to classify.

  Returns:
    One of ``"MCP"``, ``"TRANSFER_AGENT"``, ``"A2A"``, ``"SUB_AGENT"``,
    ``"LOCAL"`` or ``"UNKNOWN"``.
  """
  # pylint: disable=g-import-not-at-top
  from ..tools.agent_tool import AgentTool  # pytype: disable=import-error
  from ..tools.function_tool import FunctionTool  # pytype: disable=import-error
  from ..tools.transfer_to_agent_tool import TransferToAgentTool  # pytype: disable=import-error

  try:
    from ..tools.mcp_tool.mcp_tool import McpTool as mcp_tool_cls  # pytype: disable=import-error
  except ImportError:
    mcp_tool_cls = None

  try:
    from ..agents.remote_a2a_agent import RemoteA2aAgent as remote_agent_cls  # pytype: disable=import-error
  except ImportError:
    remote_agent_cls = None

  # The branch order is significant: TransferToAgentTool is a subclass of
  # FunctionTool, so the more specific checks must come before LOCAL.
  if mcp_tool_cls is not None and isinstance(tool, mcp_tool_cls):
    origin = "MCP"
  elif isinstance(tool, TransferToAgentTool):
    origin = "TRANSFER_AGENT"
  elif isinstance(tool, AgentTool):
    # An AgentTool wrapping a remote A2A agent is reported as A2A;
    # any other wrapped agent is a plain sub-agent.
    wraps_remote = remote_agent_cls is not None and isinstance(
        tool.agent, remote_agent_cls
    )
    origin = "A2A" if wraps_remote else "SUB_AGENT"
  elif isinstance(tool, FunctionTool):
    origin = "LOCAL"
  else:
    origin = "UNKNOWN"
  return origin
def _recursive_smart_truncate(
obj: Any, max_len: int, seen: Optional[set[int]] = None
) -> tuple[Any, bool]:
@@ -435,6 +493,11 @@ class BigQueryLoggerConfig:
log_session_metadata: bool = True
# Static custom tags (e.g. {"agent_role": "sales"})
custom_tags: dict[str, Any] = field(default_factory=dict)
# Automatically add new columns to existing tables when the plugin
# schema evolves. Only additive changes are made (columns are never
# dropped or altered). Safe to leave enabled; a version label on the
# table ensures the diff runs at most once per schema version.
auto_schema_upgrade: bool = True
# ==============================================================================
@@ -450,12 +513,17 @@ _root_agent_name_ctx = contextvars.ContextVar(
class _SpanRecord:
"""A single record on the unified span stack.
Consolidates span, token, id, ownership, and timing into one object
Consolidates span, id, ownership, and timing into one object
so all stacks stay in sync by construction.
Note: The plugin intentionally does NOT attach its spans to the
ambient OTel context (no ``context.attach``). This prevents the
plugin from corrupting the framework's span hierarchy when an
external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``)
is active. See https://github.com/google/adk-python/issues/4561.
"""
span: trace.Span
token: Any # opentelemetry context token
span_id: str
owns_span: bool
start_time_ns: int
@@ -513,17 +581,26 @@ class TraceManager:
@staticmethod
def push_span(
callback_context: CallbackContext, span_name: Optional[str] = "adk-span"
callback_context: CallbackContext,
span_name: Optional[str] = "adk-span",
) -> str:
"""Starts a new span and pushes it onto the stack.
If OTel is not configured (returning non-recording spans), a UUID fallback
is generated to ensure span_id and parent_span_id are populated in logs.
The span is created but NOT attached to the ambient OTel context,
so it cannot corrupt the framework's own span hierarchy. The
plugin tracks span_id / parent_span_id internally via its own
contextvar stack.
If OTel is not configured (returning non-recording spans), a UUID
fallback is generated to ensure span_id and parent_span_id are
populated in BigQuery logs.
"""
TraceManager.init_trace(callback_context)
# Create the span without attaching it to the ambient context.
# This avoids re-parenting framework spans like ``call_llm``
# or ``execute_tool``. See #4561.
span = tracer.start_span(span_name)
token = context.attach(trace.set_span_in_context(span))
if span.get_span_context().is_valid:
span_id_str = format(span.get_span_context().span_id, "016x")
@@ -532,7 +609,6 @@ class TraceManager:
record = _SpanRecord(
span=span,
token=token,
span_id=span_id_str,
owns_span=True,
start_time_ns=time.time_ns(),
@@ -548,11 +624,14 @@ class TraceManager:
def attach_current_span(
callback_context: CallbackContext,
) -> str:
"""Attaches the current OTEL span to the stack without owning it."""
"""Records the current OTel span on the stack without owning it.
The span is NOT re-attached to the ambient context; it is only
tracked internally for span_id / parent_span_id resolution.
"""
TraceManager.init_trace(callback_context)
span = trace.get_current_span()
token = context.attach(trace.set_span_in_context(span))
if span.get_span_context().is_valid:
span_id_str = format(span.get_span_context().span_id, "016x")
@@ -561,7 +640,6 @@ class TraceManager:
record = _SpanRecord(
span=span,
token=token,
span_id=span_id_str,
owns_span=False,
start_time_ns=time.time_ns(),
@@ -575,7 +653,11 @@ class TraceManager:
@staticmethod
def pop_span() -> tuple[Optional[str], Optional[int]]:
"""Ends the current span and pops it from the stack."""
"""Ends the current span and pops it from the stack.
No ambient OTel context is detached because we never attached
one in the first place (see ``push_span``).
"""
records = _span_records_ctx.get()
if not records:
return None, None
@@ -595,8 +677,6 @@ class TraceManager:
if record.owns_span:
record.span.end()
context.detach(record.token)
return record.span_id, duration_ms
@staticmethod
@@ -1822,16 +1902,25 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
)
def _ensure_schema_exists(self) -> None:
"""Ensures the BigQuery table exists with the correct schema."""
"""Ensures the BigQuery table exists with the correct schema.
When ``config.auto_schema_upgrade`` is True and the table already
exists, missing columns are added automatically (additive only).
A ``adk_schema_version`` label is written for governance.
"""
try:
self.client.get_table(self.full_table_id)
existing_table = self.client.get_table(self.full_table_id)
if self.config.auto_schema_upgrade:
self._maybe_upgrade_schema(existing_table)
except cloud_exceptions.NotFound:
logger.info("Table %s not found, creating table.", self.full_table_id)
tbl = bigquery.Table(self.full_table_id, schema=self._schema)
tbl.time_partitioning = bigquery.TimePartitioning(
type_=bigquery.TimePartitioningType.DAY, field="timestamp"
type_=bigquery.TimePartitioningType.DAY,
field="timestamp",
)
tbl.clustering_fields = self.config.clustering_fields
tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION}
try:
self.client.create_table(tbl)
except cloud_exceptions.Conflict:
@@ -1851,6 +1940,46 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
exc_info=True,
)
def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
  """Appends columns missing from an existing table and stamps its version.

  Nothing is done when the table's ``adk_schema_version`` label already
  equals the current schema version. Otherwise every field of
  ``self._schema`` whose name is not among the table's top-level column
  names is appended to the table schema, the version label is set, and
  both are sent in a single ``update_table`` call. Existing columns are
  never dropped or altered. Any exception raised by the update is logged
  and not re-raised.

  Args:
    existing_table: The current BigQuery table object; it is mutated in
      place (schema and labels) before being sent to ``update_table``.
  """
  table_labels = existing_table.labels or {}
  if table_labels.get(_SCHEMA_VERSION_LABEL_KEY) == _SCHEMA_VERSION:
    # Already stamped with the current version: no diff needed.
    return

  present_names = {column.name for column in existing_table.schema}
  missing_columns = []
  for column in self._schema:
    if column.name not in present_names:
      missing_columns.append(column)

  if missing_columns:
    existing_table.schema = [*existing_table.schema, *missing_columns]
    logger.info(
        "Auto-upgrading table %s: adding columns %s",
        self.full_table_id,
        [column.name for column in missing_columns],
    )

  # The label is stamped even when no column was missing, so the diff is
  # skipped on the next run.
  existing_table.labels = {
      **(existing_table.labels or {}),
      _SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION,
  }

  try:
    self.client.update_table(existing_table, ["schema", "labels"])
  except Exception as exc:
    logger.error(
        "Schema auto-upgrade failed for %s: %s",
        self.full_table_id,
        exc,
        exc_info=True,
    )
async def shutdown(self, timeout: float | None = None) -> None:
"""Shuts down the plugin and releases resources.
@@ -2123,16 +2252,42 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
) -> None:
"""Parity with V1: Logs USER_MESSAGE_RECEIVED event.
Also detects HITL completion responses (user-sent
``FunctionResponse`` parts with ``adk_request_*`` names) and emits
dedicated ``HITL_*_COMPLETED`` events.
Args:
invocation_context: The context of the current invocation.
user_message: The message content received from the user.
"""
callback_ctx = CallbackContext(invocation_context)
await self._log_event(
"USER_MESSAGE_RECEIVED",
CallbackContext(invocation_context),
callback_ctx,
raw_content=user_message,
)
# Detect HITL completion responses in the user message.
if user_message and user_message.parts:
for part in user_message.parts:
if part.function_response:
hitl_event = _HITL_EVENT_MAP.get(part.function_response.name)
if hitl_event:
resp_truncated, is_truncated = _recursive_smart_truncate(
part.function_response.response or {},
self.config.max_content_length,
)
content_dict = {
"tool": part.function_response.name,
"result": resp_truncated,
}
await self._log_event(
hitl_event + "_COMPLETED",
callback_ctx,
raw_content=content_dict,
is_truncated=is_truncated,
)
@_safe_callback
async def on_event_callback(
self,
@@ -2140,24 +2295,76 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
invocation_context: InvocationContext,
event: "Event",
) -> None:
"""Logs state changes from events to BigQuery.
"""Logs state changes and HITL events from the event stream.
Checks each event for a non-empty state_delta and logs it as a
STATE_DELTA event. This captures state changes from all sources
(tools, agents, LLM, manual), not just tool callbacks.
- Checks each event for a non-empty state_delta and logs it as a
STATE_DELTA event.
- Detects synthetic ``adk_request_*`` function calls (HITL pause
events) and their corresponding function responses (HITL
completions) and emits dedicated HITL event types.
The HITL detection must happen here (not in tool callbacks) because
``adk_request_credential``, ``adk_request_confirmation``, and
``adk_request_input`` are synthetic function calls injected by the
framework — they never go through ``before_tool_callback`` /
``after_tool_callback``.
Args:
invocation_context: The context for the current invocation.
event: The event raised by the runner.
"""
callback_ctx = CallbackContext(invocation_context)
# --- State delta logging ---
if event.actions and event.actions.state_delta:
await self._log_event(
"STATE_DELTA",
CallbackContext(invocation_context),
callback_ctx,
event_data=EventData(
extra_attributes={"state_delta": dict(event.actions.state_delta)}
),
)
# --- HITL event logging ---
if event.content and event.content.parts:
for part in event.content.parts:
# Detect HITL function calls (request events).
if part.function_call:
hitl_event = _HITL_EVENT_MAP.get(part.function_call.name)
if hitl_event:
args_truncated, is_truncated = _recursive_smart_truncate(
part.function_call.args or {},
self.config.max_content_length,
)
content_dict = {
"tool": part.function_call.name,
"args": args_truncated,
}
await self._log_event(
hitl_event,
callback_ctx,
raw_content=content_dict,
is_truncated=is_truncated,
)
# Detect HITL function responses (completion events).
if part.function_response:
hitl_event = _HITL_EVENT_MAP.get(part.function_response.name)
if hitl_event:
resp_truncated, is_truncated = _recursive_smart_truncate(
part.function_response.response or {},
self.config.max_content_length,
)
content_dict = {
"tool": part.function_response.name,
"result": resp_truncated,
}
await self._log_event(
hitl_event + "_COMPLETED",
callback_ctx,
raw_content=content_dict,
is_truncated=is_truncated,
)
return None
async def on_state_change_callback(
@@ -2460,7 +2667,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
args_truncated, is_truncated = _recursive_smart_truncate(
tool_args, self.config.max_content_length
)
content_dict = {"tool": tool.name, "args": args_truncated}
tool_origin = _get_tool_origin(tool)
content_dict = {
"tool": tool.name,
"args": args_truncated,
"tool_origin": tool_origin,
}
TraceManager.push_span(tool_context, "tool")
await self._log_event(
"TOOL_STARTING",
@@ -2489,20 +2701,26 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
resp_truncated, is_truncated = _recursive_smart_truncate(
result, self.config.max_content_length
)
content_dict = {"tool": tool.name, "result": resp_truncated}
tool_origin = _get_tool_origin(tool)
content_dict = {
"tool": tool.name,
"result": resp_truncated,
"tool_origin": tool_origin,
}
span_id, duration = TraceManager.pop_span()
parent_span_id, _ = TraceManager.get_current_span_and_parent()
event_data = EventData(
latency_ms=duration,
span_id_override=span_id,
parent_span_id_override=parent_span_id,
)
await self._log_event(
"TOOL_COMPLETED",
tool_context,
raw_content=content_dict,
is_truncated=is_truncated,
event_data=EventData(
latency_ms=duration,
span_id_override=span_id,
parent_span_id_override=parent_span_id,
),
event_data=event_data,
)
@_safe_callback
@@ -2525,7 +2743,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
args_truncated, is_truncated = _recursive_smart_truncate(
tool_args, self.config.max_content_length
)
content_dict = {"tool": tool.name, "args": args_truncated}
tool_origin = _get_tool_origin(tool)
content_dict = {
"tool": tool.name,
"args": args_truncated,
"tool_origin": tool_origin,
}
_, duration = TraceManager.pop_span()
await self._log_event(
"TOOL_ERROR",
File diff suppressed because it is too large Load Diff