From e6f7363220abeae4ec120148a81bc5a4bd4486ef Mon Sep 17 00:00:00 2001 From: Max Ind Date: Sun, 26 Oct 2025 06:06:44 -0700 Subject: [PATCH] fix: set `execute_tool {tool.name}` span attributes even when exception occurs during tool's execution PiperOrigin-RevId: 824165197 --- src/google/adk/flows/llm_flows/functions.py | 56 ++++++++++---- src/google/adk/telemetry/tracing.py | 9 ++- tests/unittests/telemetry/test_functional.py | 77 ++++++++++++++------ 3 files changed, 103 insertions(+), 39 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 4380322b..40e0be25 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -305,7 +305,9 @@ async def _execute_single_function_call_async( else: raise tool_error - with tracer.start_as_current_span(f'execute_tool {tool.name}'): + async def _run_with_trace(): + nonlocal function_args + # Step 1: Check if plugin before_tool_callback overrides the function # response. function_response = ( @@ -391,13 +393,23 @@ async def _execute_single_function_call_async( function_response_event = __build_response_event( tool, function_response, tool_context, invocation_context ) - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) return function_response_event + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + async def handle_function_calls_live( invocation_context: InvocationContext, @@ -467,13 +479,17 @@ async def _execute_single_function_call_live( tool, tool_context = _get_tool_and_context( invocation_context, function_call, tools_dict ) - with tracer.start_as_current_span(f'execute_tool {tool.name}'): + + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + async def _run_with_trace(): + nonlocal function_args + # Do not use "args" as the variable name, because it is a reserved keyword # in python debugger. # Make a deep copy to avoid being modified. - function_args = ( - copy.deepcopy(function_call.args) if function_call.args else {} - ) function_response = None # Handle before_tool_callbacks - iterate through the canonical callback @@ -527,13 +543,23 @@ async def _execute_single_function_call_live( function_response_event = __build_response_event( tool, function_response, tool_context, invocation_context ) - trace_tool_call( - tool=tool, - args=function_args, - function_response_event=function_response_event, - ) return function_response_event + with tracer.start_as_current_span(f'execute_tool {tool.name}'): + try: + function_response_event = await _run_with_trace() + trace_tool_call( + tool=tool, + args=function_args, + function_response_event=function_response_event, + ) + return function_response_event + except: + trace_tool_call( + tool=tool, args=function_args, function_response_event=None + ) + raise + async def _process_function_live_helper( tool, diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index cc0a2a23..021471a5 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -26,6 +26,7 @@ from __future__ import annotations import json import os from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types @@ -118,7 +119,7 @@ def trace_agent_invocation( def trace_tool_call( tool: BaseTool, args: dict[str, Any], - function_response_event: Event, + function_response_event: Optional[Event], ): """Traces tool call. @@ -154,7 +155,8 @@ def trace_tool_call( tool_call_id = '' tool_response = '' if ( - function_response_event.content is not None + function_response_event is not None + and function_response_event.content is not None and function_response_event.content.parts ): response_parts = function_response_event.content.parts @@ -169,7 +171,8 @@ def trace_tool_call( if not isinstance(tool_response, dict): tool_response = {'result': tool_response} - span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) + if function_response_event is not None: + span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id) if _should_add_request_response_to_spans(): span.set_attribute( 'gcp.vertex.agent.tool_response', diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index 100ec40a..f3b29c54 100644 --- a/tests/unittests/telemetry/test_functional.py +++ b/tests/unittests/telemetry/test_functional.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import gc import sys -from unittest import mock from google.adk.agents import base_agent from google.adk.agents.llm_agent import Agent from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_response import LlmResponse from google.adk.telemetry import tracing from google.adk.tools import FunctionTool from google.adk.utils.context_utils import Aclosing +from google.genai.types import Content from google.genai.types import Part -from opentelemetry.version import __version__ +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter import pytest from ..testing_utils import MockModel @@ -63,27 +67,27 @@ async def test_runner(test_agent: Agent) -> TestInMemoryRunner: @pytest.fixture -def mock_start_as_current_span(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: - mock_context_manager = mock.MagicMock() - mock_context_manager.__enter__.return_value = mock.Mock() - mock_start_as_current_span = mock.Mock() - mock_start_as_current_span.return_value = mock_context_manager +def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter: + tracer_provider = TracerProvider() + span_exporter = InMemorySpanExporter() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + real_tracer = tracer_provider.get_tracer(__name__) def do_replace(tracer): monkeypatch.setattr( - tracer, 'start_as_current_span', mock_start_as_current_span + tracer, 'start_as_current_span', real_tracer.start_as_current_span ) do_replace(tracing.tracer) do_replace(base_agent.tracer) - return mock_start_as_current_span + return span_exporter @pytest.mark.asyncio async def test_tracer_start_as_current_span( test_runner: TestInMemoryRunner, - mock_start_as_current_span: mock.Mock, + span_exporter: InMemorySpanExporter, ): """Test creation of multiple spans in an E2E runner invocation. @@ -112,18 +116,49 @@ async def test_tracer_start_as_current_span( pass # Assert - expected_start_as_current_span_calls = [ - mock.call('invocation'), - mock.call('execute_tool some_tool'), - mock.call('invoke_agent some_root_agent'), - mock.call('call_llm'), - mock.call('call_llm'), + spans = span_exporter.get_finished_spans() + assert list(sorted(span.name for span in spans)) == [ + 'call_llm', + 'call_llm', + 'execute_tool some_tool', + 'invocation', + 'invoke_agent some_root_agent', ] - mock_start_as_current_span.assert_has_calls( - expected_start_as_current_span_calls, - any_order=True, + +@pytest.mark.asyncio +async def test_exception_preserves_attributes( + test_model: BaseLlm, span_exporter: InMemorySpanExporter +): + """Test when an exception occurs during tool execution, span attributes are still present on spans where they are expected.""" + + # Arrange + async def some_tool(): + raise ValueError('This tool always fails') + + test_agent = Agent( + name='some_root_agent', + model=test_model, + tools=[ + FunctionTool(some_tool), + ], ) - assert mock_start_as_current_span.call_count == len( - expected_start_as_current_span_calls + + test_runner = TestInMemoryRunner(test_agent) + + # Act + with pytest.raises(ValueError, match='This tool always fails'): + async with Aclosing( + test_runner.run_async_with_new_session_agen('') + ) as agen: + async for _ in agen: + pass + + # Assert + spans = span_exporter.get_finished_spans() + assert len(spans) > 1 + assert all( + span.attributes is not None and len(span.attributes) > 0 + for span in spans + if span.name != 'invocation' # not expected to have attributes )