You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add schema auto-upgrade, tool provenance, HITL tracing, and span hierarchy fix to BigQuery Agent Analytics plugin
This CL adds four enhancements to the BigQuery Agent Analytics plugin and fixes a span hierarchy corruption bug. - **Schema Auto-Upgrade:** Additive-only schema migration that automatically adds missing columns to existing BQ tables on startup. A `adk_schema_version` label on the table (starting at `"1"`, bumped with each schema change) makes the check idempotent — the diff runs at most once per version. Enabled by default (`auto_schema_upgrade=True`) because upgrades are additive-only and fail-safe. Pre-versioning tables (no label) are treated as outdated, diffed, and stamped. No previous schema versions need to be stored; the logic diffs actual columns against the current canonical schema. - **Tool Provenance:** Adds `tool_origin` to TOOL_* event content, distinguishing six origin types — `LOCAL` (FunctionTool), `MCP` (McpTool), `A2A` (AgentTool wrapping RemoteA2aAgent), `SUB_AGENT` (AgentTool), `TRANSFER_AGENT` (TransferToAgentTool), and `UNKNOWN` (fallback) — via `isinstance()` checks with lazy imports to avoid circular dependencies. - **HITL Tracing:** Emits dedicated HITL event types (`HITL_CONFIRMATION_REQUEST`, `HITL_CREDENTIAL_REQUEST`, `HITL_INPUT_REQUEST` + `_COMPLETED` variants) for human-in-the-loop interactions. Detection lives in `on_event_callback` (for synthetic `adk_request_*` FunctionCall events emitted by the framework) and `on_user_message_callback` (for `adk_request_*` FunctionResponse completions sent by the user), not in tool callbacks — because `adk_request_*` names are synthetic function calls that bypass `before_tool_callback`/`after_tool_callback` entirely. - **Span Hierarchy Fix (#4561):** Removes `context.attach()`/`context.detach()` calls from `TraceManager.push_span()`, `attach_current_span()`, and `pop_span()`. The plugin was injecting its spans into the shared OTel context, which corrupted the framework's span hierarchy when an external exporter (e.g. `opentelemetry-instrumentation-vertexai`) was active — causing `call_llm` to be re-parented under `llm_request` and parent spans to show shorter durations than children. The plugin now tracks span_id/parent_span_id via its internal contextvar stack without mutating ambient OTel context. Co-authored-by: Haiyuan Cao <haiyuan@google.com> PiperOrigin-RevId: 873114688
This commit is contained in:
committed by
Copybara-Service
parent
e8019b1b1b
commit
4260ef0c7c
@@ -51,7 +51,6 @@ from google.cloud.bigquery import schema as bq_schema
|
||||
from google.cloud.bigquery_storage_v1 import types as bq_storage_types
|
||||
from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
|
||||
from google.genai import types
|
||||
from opentelemetry import context
|
||||
from opentelemetry import trace
|
||||
import pyarrow as pa
|
||||
|
||||
@@ -71,6 +70,24 @@ tracer = trace.get_tracer(
|
||||
"google.adk.plugins.bigquery_agent_analytics", __version__
|
||||
)
|
||||
|
||||
# Bumped when the schema changes (1 → 2 → 3 …). Used as a table
|
||||
# label for governance and to decide whether auto-upgrade should run.
|
||||
_SCHEMA_VERSION = "1"
|
||||
_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version"
|
||||
|
||||
# Human-in-the-loop (HITL) tool names that receive additional
|
||||
# dedicated event types alongside the normal TOOL_* events.
|
||||
_HITL_TOOL_NAMES = frozenset({
|
||||
"adk_request_credential",
|
||||
"adk_request_confirmation",
|
||||
"adk_request_input",
|
||||
})
|
||||
_HITL_EVENT_MAP = MappingProxyType({
|
||||
"adk_request_credential": "HITL_CREDENTIAL_REQUEST",
|
||||
"adk_request_confirmation": "HITL_CONFIRMATION_REQUEST",
|
||||
"adk_request_input": "HITL_INPUT_REQUEST",
|
||||
})
|
||||
|
||||
|
||||
def _safe_callback(func):
|
||||
"""Decorator that catches and logs exceptions in plugin callbacks.
|
||||
@@ -132,6 +149,47 @@ def _format_content(
|
||||
return " | ".join(parts), truncated
|
||||
|
||||
|
||||
def _get_tool_origin(tool: "BaseTool") -> str:
|
||||
"""Returns the provenance category of a tool.
|
||||
|
||||
Uses lazy imports to avoid circular dependencies.
|
||||
|
||||
Args:
|
||||
tool: The tool instance.
|
||||
|
||||
Returns:
|
||||
One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN.
|
||||
"""
|
||||
# Import lazily to avoid circular dependencies.
|
||||
# pylint: disable=g-import-not-at-top
|
||||
from ..tools.agent_tool import AgentTool # pytype: disable=import-error
|
||||
from ..tools.function_tool import FunctionTool # pytype: disable=import-error
|
||||
from ..tools.transfer_to_agent_tool import TransferToAgentTool # pytype: disable=import-error
|
||||
|
||||
try:
|
||||
from ..tools.mcp_tool.mcp_tool import McpTool # pytype: disable=import-error
|
||||
except ImportError:
|
||||
McpTool = None
|
||||
|
||||
try:
|
||||
from ..agents.remote_a2a_agent import RemoteA2aAgent # pytype: disable=import-error
|
||||
except ImportError:
|
||||
RemoteA2aAgent = None
|
||||
|
||||
# Order matters: TransferToAgentTool is a subclass of FunctionTool.
|
||||
if McpTool is not None and isinstance(tool, McpTool):
|
||||
return "MCP"
|
||||
if isinstance(tool, TransferToAgentTool):
|
||||
return "TRANSFER_AGENT"
|
||||
if isinstance(tool, AgentTool):
|
||||
if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent):
|
||||
return "A2A"
|
||||
return "SUB_AGENT"
|
||||
if isinstance(tool, FunctionTool):
|
||||
return "LOCAL"
|
||||
return "UNKNOWN"
|
||||
|
||||
|
||||
def _recursive_smart_truncate(
|
||||
obj: Any, max_len: int, seen: Optional[set[int]] = None
|
||||
) -> tuple[Any, bool]:
|
||||
@@ -435,6 +493,11 @@ class BigQueryLoggerConfig:
|
||||
log_session_metadata: bool = True
|
||||
# Static custom tags (e.g. {"agent_role": "sales"})
|
||||
custom_tags: dict[str, Any] = field(default_factory=dict)
|
||||
# Automatically add new columns to existing tables when the plugin
|
||||
# schema evolves. Only additive changes are made (columns are never
|
||||
# dropped or altered). Safe to leave enabled; a version label on the
|
||||
# table ensures the diff runs at most once per schema version.
|
||||
auto_schema_upgrade: bool = True
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
@@ -450,12 +513,17 @@ _root_agent_name_ctx = contextvars.ContextVar(
|
||||
class _SpanRecord:
|
||||
"""A single record on the unified span stack.
|
||||
|
||||
Consolidates span, token, id, ownership, and timing into one object
|
||||
Consolidates span, id, ownership, and timing into one object
|
||||
so all stacks stay in sync by construction.
|
||||
|
||||
Note: The plugin intentionally does NOT attach its spans to the
|
||||
ambient OTel context (no ``context.attach``). This prevents the
|
||||
plugin from corrupting the framework's span hierarchy when an
|
||||
external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``)
|
||||
is active. See https://github.com/google/adk-python/issues/4561.
|
||||
"""
|
||||
|
||||
span: trace.Span
|
||||
token: Any # opentelemetry context token
|
||||
span_id: str
|
||||
owns_span: bool
|
||||
start_time_ns: int
|
||||
@@ -513,17 +581,26 @@ class TraceManager:
|
||||
|
||||
@staticmethod
|
||||
def push_span(
|
||||
callback_context: CallbackContext, span_name: Optional[str] = "adk-span"
|
||||
callback_context: CallbackContext,
|
||||
span_name: Optional[str] = "adk-span",
|
||||
) -> str:
|
||||
"""Starts a new span and pushes it onto the stack.
|
||||
|
||||
If OTel is not configured (returning non-recording spans), a UUID fallback
|
||||
is generated to ensure span_id and parent_span_id are populated in logs.
|
||||
The span is created but NOT attached to the ambient OTel context,
|
||||
so it cannot corrupt the framework's own span hierarchy. The
|
||||
plugin tracks span_id / parent_span_id internally via its own
|
||||
contextvar stack.
|
||||
|
||||
If OTel is not configured (returning non-recording spans), a UUID
|
||||
fallback is generated to ensure span_id and parent_span_id are
|
||||
populated in BigQuery logs.
|
||||
"""
|
||||
TraceManager.init_trace(callback_context)
|
||||
|
||||
# Create the span without attaching it to the ambient context.
|
||||
# This avoids re-parenting framework spans like ``call_llm``
|
||||
# or ``execute_tool``. See #4561.
|
||||
span = tracer.start_span(span_name)
|
||||
token = context.attach(trace.set_span_in_context(span))
|
||||
|
||||
if span.get_span_context().is_valid:
|
||||
span_id_str = format(span.get_span_context().span_id, "016x")
|
||||
@@ -532,7 +609,6 @@ class TraceManager:
|
||||
|
||||
record = _SpanRecord(
|
||||
span=span,
|
||||
token=token,
|
||||
span_id=span_id_str,
|
||||
owns_span=True,
|
||||
start_time_ns=time.time_ns(),
|
||||
@@ -548,11 +624,14 @@ class TraceManager:
|
||||
def attach_current_span(
|
||||
callback_context: CallbackContext,
|
||||
) -> str:
|
||||
"""Attaches the current OTEL span to the stack without owning it."""
|
||||
"""Records the current OTel span on the stack without owning it.
|
||||
|
||||
The span is NOT re-attached to the ambient context; it is only
|
||||
tracked internally for span_id / parent_span_id resolution.
|
||||
"""
|
||||
TraceManager.init_trace(callback_context)
|
||||
|
||||
span = trace.get_current_span()
|
||||
token = context.attach(trace.set_span_in_context(span))
|
||||
|
||||
if span.get_span_context().is_valid:
|
||||
span_id_str = format(span.get_span_context().span_id, "016x")
|
||||
@@ -561,7 +640,6 @@ class TraceManager:
|
||||
|
||||
record = _SpanRecord(
|
||||
span=span,
|
||||
token=token,
|
||||
span_id=span_id_str,
|
||||
owns_span=False,
|
||||
start_time_ns=time.time_ns(),
|
||||
@@ -575,7 +653,11 @@ class TraceManager:
|
||||
|
||||
@staticmethod
|
||||
def pop_span() -> tuple[Optional[str], Optional[int]]:
|
||||
"""Ends the current span and pops it from the stack."""
|
||||
"""Ends the current span and pops it from the stack.
|
||||
|
||||
No ambient OTel context is detached because we never attached
|
||||
one in the first place (see ``push_span``).
|
||||
"""
|
||||
records = _span_records_ctx.get()
|
||||
if not records:
|
||||
return None, None
|
||||
@@ -595,8 +677,6 @@ class TraceManager:
|
||||
if record.owns_span:
|
||||
record.span.end()
|
||||
|
||||
context.detach(record.token)
|
||||
|
||||
return record.span_id, duration_ms
|
||||
|
||||
@staticmethod
|
||||
@@ -1822,16 +1902,25 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
)
|
||||
|
||||
def _ensure_schema_exists(self) -> None:
|
||||
"""Ensures the BigQuery table exists with the correct schema."""
|
||||
"""Ensures the BigQuery table exists with the correct schema.
|
||||
|
||||
When ``config.auto_schema_upgrade`` is True and the table already
|
||||
exists, missing columns are added automatically (additive only).
|
||||
A ``adk_schema_version`` label is written for governance.
|
||||
"""
|
||||
try:
|
||||
self.client.get_table(self.full_table_id)
|
||||
existing_table = self.client.get_table(self.full_table_id)
|
||||
if self.config.auto_schema_upgrade:
|
||||
self._maybe_upgrade_schema(existing_table)
|
||||
except cloud_exceptions.NotFound:
|
||||
logger.info("Table %s not found, creating table.", self.full_table_id)
|
||||
tbl = bigquery.Table(self.full_table_id, schema=self._schema)
|
||||
tbl.time_partitioning = bigquery.TimePartitioning(
|
||||
type_=bigquery.TimePartitioningType.DAY, field="timestamp"
|
||||
type_=bigquery.TimePartitioningType.DAY,
|
||||
field="timestamp",
|
||||
)
|
||||
tbl.clustering_fields = self.config.clustering_fields
|
||||
tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION}
|
||||
try:
|
||||
self.client.create_table(tbl)
|
||||
except cloud_exceptions.Conflict:
|
||||
@@ -1851,6 +1940,46 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None:
|
||||
"""Adds missing columns to an existing table (additive only).
|
||||
|
||||
Args:
|
||||
existing_table: The current BigQuery table object.
|
||||
"""
|
||||
stored_version = (existing_table.labels or {}).get(
|
||||
_SCHEMA_VERSION_LABEL_KEY
|
||||
)
|
||||
if stored_version == _SCHEMA_VERSION:
|
||||
return
|
||||
|
||||
existing_names = {f.name for f in existing_table.schema}
|
||||
new_fields = [f for f in self._schema if f.name not in existing_names]
|
||||
|
||||
if new_fields:
|
||||
merged = list(existing_table.schema) + new_fields
|
||||
existing_table.schema = merged
|
||||
logger.info(
|
||||
"Auto-upgrading table %s: adding columns %s",
|
||||
self.full_table_id,
|
||||
[f.name for f in new_fields],
|
||||
)
|
||||
|
||||
# Always stamp the version label so we skip on next run.
|
||||
labels = dict(existing_table.labels or {})
|
||||
labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION
|
||||
existing_table.labels = labels
|
||||
|
||||
try:
|
||||
update_fields = ["schema", "labels"]
|
||||
self.client.update_table(existing_table, update_fields)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Schema auto-upgrade failed for %s: %s",
|
||||
self.full_table_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shuts down the plugin and releases resources.
|
||||
|
||||
@@ -2123,16 +2252,42 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
) -> None:
|
||||
"""Parity with V1: Logs USER_MESSAGE_RECEIVED event.
|
||||
|
||||
Also detects HITL completion responses (user-sent
|
||||
``FunctionResponse`` parts with ``adk_request_*`` names) and emits
|
||||
dedicated ``HITL_*_COMPLETED`` events.
|
||||
|
||||
Args:
|
||||
invocation_context: The context of the current invocation.
|
||||
user_message: The message content received from the user.
|
||||
"""
|
||||
callback_ctx = CallbackContext(invocation_context)
|
||||
await self._log_event(
|
||||
"USER_MESSAGE_RECEIVED",
|
||||
CallbackContext(invocation_context),
|
||||
callback_ctx,
|
||||
raw_content=user_message,
|
||||
)
|
||||
|
||||
# Detect HITL completion responses in the user message.
|
||||
if user_message and user_message.parts:
|
||||
for part in user_message.parts:
|
||||
if part.function_response:
|
||||
hitl_event = _HITL_EVENT_MAP.get(part.function_response.name)
|
||||
if hitl_event:
|
||||
resp_truncated, is_truncated = _recursive_smart_truncate(
|
||||
part.function_response.response or {},
|
||||
self.config.max_content_length,
|
||||
)
|
||||
content_dict = {
|
||||
"tool": part.function_response.name,
|
||||
"result": resp_truncated,
|
||||
}
|
||||
await self._log_event(
|
||||
hitl_event + "_COMPLETED",
|
||||
callback_ctx,
|
||||
raw_content=content_dict,
|
||||
is_truncated=is_truncated,
|
||||
)
|
||||
|
||||
@_safe_callback
|
||||
async def on_event_callback(
|
||||
self,
|
||||
@@ -2140,24 +2295,76 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
invocation_context: InvocationContext,
|
||||
event: "Event",
|
||||
) -> None:
|
||||
"""Logs state changes from events to BigQuery.
|
||||
"""Logs state changes and HITL events from the event stream.
|
||||
|
||||
Checks each event for a non-empty state_delta and logs it as a
|
||||
STATE_DELTA event. This captures state changes from all sources
|
||||
(tools, agents, LLM, manual), not just tool callbacks.
|
||||
- Checks each event for a non-empty state_delta and logs it as a
|
||||
STATE_DELTA event.
|
||||
- Detects synthetic ``adk_request_*`` function calls (HITL pause
|
||||
events) and their corresponding function responses (HITL
|
||||
completions) and emits dedicated HITL event types.
|
||||
|
||||
The HITL detection must happen here (not in tool callbacks) because
|
||||
``adk_request_credential``, ``adk_request_confirmation``, and
|
||||
``adk_request_input`` are synthetic function calls injected by the
|
||||
framework — they never go through ``before_tool_callback`` /
|
||||
``after_tool_callback``.
|
||||
|
||||
Args:
|
||||
invocation_context: The context for the current invocation.
|
||||
event: The event raised by the runner.
|
||||
"""
|
||||
callback_ctx = CallbackContext(invocation_context)
|
||||
|
||||
# --- State delta logging ---
|
||||
if event.actions and event.actions.state_delta:
|
||||
await self._log_event(
|
||||
"STATE_DELTA",
|
||||
CallbackContext(invocation_context),
|
||||
callback_ctx,
|
||||
event_data=EventData(
|
||||
extra_attributes={"state_delta": dict(event.actions.state_delta)}
|
||||
),
|
||||
)
|
||||
|
||||
# --- HITL event logging ---
|
||||
if event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
# Detect HITL function calls (request events).
|
||||
if part.function_call:
|
||||
hitl_event = _HITL_EVENT_MAP.get(part.function_call.name)
|
||||
if hitl_event:
|
||||
args_truncated, is_truncated = _recursive_smart_truncate(
|
||||
part.function_call.args or {},
|
||||
self.config.max_content_length,
|
||||
)
|
||||
content_dict = {
|
||||
"tool": part.function_call.name,
|
||||
"args": args_truncated,
|
||||
}
|
||||
await self._log_event(
|
||||
hitl_event,
|
||||
callback_ctx,
|
||||
raw_content=content_dict,
|
||||
is_truncated=is_truncated,
|
||||
)
|
||||
# Detect HITL function responses (completion events).
|
||||
if part.function_response:
|
||||
hitl_event = _HITL_EVENT_MAP.get(part.function_response.name)
|
||||
if hitl_event:
|
||||
resp_truncated, is_truncated = _recursive_smart_truncate(
|
||||
part.function_response.response or {},
|
||||
self.config.max_content_length,
|
||||
)
|
||||
content_dict = {
|
||||
"tool": part.function_response.name,
|
||||
"result": resp_truncated,
|
||||
}
|
||||
await self._log_event(
|
||||
hitl_event + "_COMPLETED",
|
||||
callback_ctx,
|
||||
raw_content=content_dict,
|
||||
is_truncated=is_truncated,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def on_state_change_callback(
|
||||
@@ -2460,7 +2667,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
args_truncated, is_truncated = _recursive_smart_truncate(
|
||||
tool_args, self.config.max_content_length
|
||||
)
|
||||
content_dict = {"tool": tool.name, "args": args_truncated}
|
||||
tool_origin = _get_tool_origin(tool)
|
||||
content_dict = {
|
||||
"tool": tool.name,
|
||||
"args": args_truncated,
|
||||
"tool_origin": tool_origin,
|
||||
}
|
||||
TraceManager.push_span(tool_context, "tool")
|
||||
await self._log_event(
|
||||
"TOOL_STARTING",
|
||||
@@ -2489,20 +2701,26 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
resp_truncated, is_truncated = _recursive_smart_truncate(
|
||||
result, self.config.max_content_length
|
||||
)
|
||||
content_dict = {"tool": tool.name, "result": resp_truncated}
|
||||
tool_origin = _get_tool_origin(tool)
|
||||
content_dict = {
|
||||
"tool": tool.name,
|
||||
"result": resp_truncated,
|
||||
"tool_origin": tool_origin,
|
||||
}
|
||||
span_id, duration = TraceManager.pop_span()
|
||||
parent_span_id, _ = TraceManager.get_current_span_and_parent()
|
||||
|
||||
event_data = EventData(
|
||||
latency_ms=duration,
|
||||
span_id_override=span_id,
|
||||
parent_span_id_override=parent_span_id,
|
||||
)
|
||||
await self._log_event(
|
||||
"TOOL_COMPLETED",
|
||||
tool_context,
|
||||
raw_content=content_dict,
|
||||
is_truncated=is_truncated,
|
||||
event_data=EventData(
|
||||
latency_ms=duration,
|
||||
span_id_override=span_id,
|
||||
parent_span_id_override=parent_span_id,
|
||||
),
|
||||
event_data=event_data,
|
||||
)
|
||||
|
||||
@_safe_callback
|
||||
@@ -2525,7 +2743,12 @@ class BigQueryAgentAnalyticsPlugin(BasePlugin):
|
||||
args_truncated, is_truncated = _recursive_smart_truncate(
|
||||
tool_args, self.config.max_content_length
|
||||
)
|
||||
content_dict = {"tool": tool.name, "args": args_truncated}
|
||||
tool_origin = _get_tool_origin(tool)
|
||||
content_dict = {
|
||||
"tool": tool.name,
|
||||
"args": args_truncated,
|
||||
"tool_origin": tool_origin,
|
||||
}
|
||||
_, duration = TraceManager.pop_span()
|
||||
await self._log_event(
|
||||
"TOOL_ERROR",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user