fix: set execute_tool {tool.name} span attributes even when an exception occurs during the tool's execution

PiperOrigin-RevId: 824165197
This commit is contained in:
Max Ind
2025-10-26 06:06:44 -07:00
committed by Copybara-Service
parent 5d9a7e7f79
commit e6f7363220
3 changed files with 103 additions and 39 deletions
+41 -15
View File
@@ -305,7 +305,9 @@ async def _execute_single_function_call_async(
else:
raise tool_error
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
async def _run_with_trace():
nonlocal function_args
# Step 1: Check if plugin before_tool_callback overrides the function
# response.
function_response = (
@@ -391,13 +393,23 @@ async def _execute_single_function_call_async(
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
return function_response_event
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
try:
function_response_event = await _run_with_trace()
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
return function_response_event
except:
trace_tool_call(
tool=tool, args=function_args, function_response_event=None
)
raise
async def handle_function_calls_live(
invocation_context: InvocationContext,
@@ -467,13 +479,17 @@ async def _execute_single_function_call_live(
tool, tool_context = _get_tool_and_context(
invocation_context, function_call, tools_dict
)
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
async def _run_with_trace():
nonlocal function_args
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
function_response = None
# Handle before_tool_callbacks - iterate through the canonical callback
@@ -527,13 +543,23 @@ async def _execute_single_function_call_live(
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
)
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
return function_response_event
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
try:
function_response_event = await _run_with_trace()
trace_tool_call(
tool=tool,
args=function_args,
function_response_event=function_response_event,
)
return function_response_event
except:
trace_tool_call(
tool=tool, args=function_args, function_response_event=None
)
raise
async def _process_function_live_helper(
tool,
+6 -3
View File
@@ -26,6 +26,7 @@ from __future__ import annotations
import json
import os
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
@@ -118,7 +119,7 @@ def trace_agent_invocation(
def trace_tool_call(
tool: BaseTool,
args: dict[str, Any],
function_response_event: Event,
function_response_event: Optional[Event],
):
"""Traces tool call.
@@ -154,7 +155,8 @@ def trace_tool_call(
tool_call_id = '<not specified>'
tool_response = '<not specified>'
if (
function_response_event.content is not None
function_response_event is not None
and function_response_event.content is not None
and function_response_event.content.parts
):
response_parts = function_response_event.content.parts
@@ -169,7 +171,8 @@ def trace_tool_call(
if not isinstance(tool_response, dict):
tool_response = {'result': tool_response}
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
if function_response_event is not None:
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
if _should_add_request_response_to_spans():
span.set_attribute(
'gcp.vertex.agent.tool_response',
+56 -21
View File
@@ -12,18 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import gc
import sys
from unittest import mock
from google.adk.agents import base_agent
from google.adk.agents.llm_agent import Agent
from google.adk.models.base_llm import BaseLlm
from google.adk.models.llm_response import LlmResponse
from google.adk.telemetry import tracing
from google.adk.tools import FunctionTool
from google.adk.utils.context_utils import Aclosing
from google.genai.types import Content
from google.genai.types import Part
from opentelemetry.version import __version__
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
import pytest
from ..testing_utils import MockModel
@@ -63,27 +67,27 @@ async def test_runner(test_agent: Agent) -> TestInMemoryRunner:
@pytest.fixture
def mock_start_as_current_span(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
mock_context_manager = mock.MagicMock()
mock_context_manager.__enter__.return_value = mock.Mock()
mock_start_as_current_span = mock.Mock()
mock_start_as_current_span.return_value = mock_context_manager
def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter:
tracer_provider = TracerProvider()
span_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
real_tracer = tracer_provider.get_tracer(__name__)
def do_replace(tracer):
monkeypatch.setattr(
tracer, 'start_as_current_span', mock_start_as_current_span
tracer, 'start_as_current_span', real_tracer.start_as_current_span
)
do_replace(tracing.tracer)
do_replace(base_agent.tracer)
return mock_start_as_current_span
return span_exporter
@pytest.mark.asyncio
async def test_tracer_start_as_current_span(
test_runner: TestInMemoryRunner,
mock_start_as_current_span: mock.Mock,
span_exporter: InMemorySpanExporter,
):
"""Test creation of multiple spans in an E2E runner invocation.
@@ -112,18 +116,49 @@ async def test_tracer_start_as_current_span(
pass
# Assert
expected_start_as_current_span_calls = [
mock.call('invocation'),
mock.call('execute_tool some_tool'),
mock.call('invoke_agent some_root_agent'),
mock.call('call_llm'),
mock.call('call_llm'),
spans = span_exporter.get_finished_spans()
assert list(sorted(span.name for span in spans)) == [
'call_llm',
'call_llm',
'execute_tool some_tool',
'invocation',
'invoke_agent some_root_agent',
]
mock_start_as_current_span.assert_has_calls(
expected_start_as_current_span_calls,
any_order=True,
@pytest.mark.asyncio
async def test_exception_preserves_attributes(
test_model: BaseLlm, span_exporter: InMemorySpanExporter
):
"""Test when an exception occurs during tool execution, span attributes are still present on spans where they are expected."""
# Arrange
async def some_tool():
raise ValueError('This tool always fails')
test_agent = Agent(
name='some_root_agent',
model=test_model,
tools=[
FunctionTool(some_tool),
],
)
assert mock_start_as_current_span.call_count == len(
expected_start_as_current_span_calls
test_runner = TestInMemoryRunner(test_agent)
# Act
with pytest.raises(ValueError, match='This tool always fails'):
async with Aclosing(
test_runner.run_async_with_new_session_agen('')
) as agen:
async for _ in agen:
pass
# Assert
spans = span_exporter.get_finished_spans()
assert len(spans) > 1
assert all(
span.attributes is not None and len(span.attributes) > 0
for span in spans
if span.name != 'invocation' # not expected to have attributes
)