You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: set execute_tool {tool.name} span attributes even when exception occurs during tool's execution
PiperOrigin-RevId: 824165197
This commit is contained in:
committed by
Copybara-Service
parent
5d9a7e7f79
commit
e6f7363220
@@ -305,7 +305,9 @@ async def _execute_single_function_call_async(
|
||||
else:
|
||||
raise tool_error
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
async def _run_with_trace():
|
||||
nonlocal function_args
|
||||
|
||||
# Step 1: Check if plugin before_tool_callback overrides the function
|
||||
# response.
|
||||
function_response = (
|
||||
@@ -391,13 +393,23 @@ async def _execute_single_function_call_async(
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
try:
|
||||
function_response_event = await _run_with_trace()
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
except:
|
||||
trace_tool_call(
|
||||
tool=tool, args=function_args, function_response_event=None
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def handle_function_calls_live(
|
||||
invocation_context: InvocationContext,
|
||||
@@ -467,13 +479,17 @@ async def _execute_single_function_call_live(
|
||||
tool, tool_context = _get_tool_and_context(
|
||||
invocation_context, function_call, tools_dict
|
||||
)
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
|
||||
function_args = (
|
||||
copy.deepcopy(function_call.args) if function_call.args else {}
|
||||
)
|
||||
|
||||
async def _run_with_trace():
|
||||
nonlocal function_args
|
||||
|
||||
# Do not use "args" as the variable name, because it is a reserved keyword
|
||||
# in python debugger.
|
||||
# Make a deep copy to avoid being modified.
|
||||
function_args = (
|
||||
copy.deepcopy(function_call.args) if function_call.args else {}
|
||||
)
|
||||
function_response = None
|
||||
|
||||
# Handle before_tool_callbacks - iterate through the canonical callback
|
||||
@@ -527,13 +543,23 @@ async def _execute_single_function_call_live(
|
||||
function_response_event = __build_response_event(
|
||||
tool, function_response, tool_context, invocation_context
|
||||
)
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
|
||||
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
|
||||
try:
|
||||
function_response_event = await _run_with_trace()
|
||||
trace_tool_call(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
function_response_event=function_response_event,
|
||||
)
|
||||
return function_response_event
|
||||
except:
|
||||
trace_tool_call(
|
||||
tool=tool, args=function_args, function_response_event=None
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
async def _process_function_live_helper(
|
||||
tool,
|
||||
|
||||
@@ -26,6 +26,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
@@ -118,7 +119,7 @@ def trace_agent_invocation(
|
||||
def trace_tool_call(
|
||||
tool: BaseTool,
|
||||
args: dict[str, Any],
|
||||
function_response_event: Event,
|
||||
function_response_event: Optional[Event],
|
||||
):
|
||||
"""Traces tool call.
|
||||
|
||||
@@ -154,7 +155,8 @@ def trace_tool_call(
|
||||
tool_call_id = '<not specified>'
|
||||
tool_response = '<not specified>'
|
||||
if (
|
||||
function_response_event.content is not None
|
||||
function_response_event is not None
|
||||
and function_response_event.content is not None
|
||||
and function_response_event.content.parts
|
||||
):
|
||||
response_parts = function_response_event.content.parts
|
||||
@@ -169,7 +171,8 @@ def trace_tool_call(
|
||||
|
||||
if not isinstance(tool_response, dict):
|
||||
tool_response = {'result': tool_response}
|
||||
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
|
||||
if function_response_event is not None:
|
||||
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
|
||||
if _should_add_request_response_to_spans():
|
||||
span.set_attribute(
|
||||
'gcp.vertex.agent.tool_response',
|
||||
|
||||
@@ -12,18 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.agents import base_agent
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.models.base_llm import BaseLlm
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.telemetry import tracing
|
||||
from google.adk.tools import FunctionTool
|
||||
from google.adk.utils.context_utils import Aclosing
|
||||
from google.genai.types import Content
|
||||
from google.genai.types import Part
|
||||
from opentelemetry.version import __version__
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
import pytest
|
||||
|
||||
from ..testing_utils import MockModel
|
||||
@@ -63,27 +67,27 @@ async def test_runner(test_agent: Agent) -> TestInMemoryRunner:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_start_as_current_span(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
|
||||
mock_context_manager = mock.MagicMock()
|
||||
mock_context_manager.__enter__.return_value = mock.Mock()
|
||||
mock_start_as_current_span = mock.Mock()
|
||||
mock_start_as_current_span.return_value = mock_context_manager
|
||||
def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter:
|
||||
tracer_provider = TracerProvider()
|
||||
span_exporter = InMemorySpanExporter()
|
||||
tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
|
||||
real_tracer = tracer_provider.get_tracer(__name__)
|
||||
|
||||
def do_replace(tracer):
|
||||
monkeypatch.setattr(
|
||||
tracer, 'start_as_current_span', mock_start_as_current_span
|
||||
tracer, 'start_as_current_span', real_tracer.start_as_current_span
|
||||
)
|
||||
|
||||
do_replace(tracing.tracer)
|
||||
do_replace(base_agent.tracer)
|
||||
|
||||
return mock_start_as_current_span
|
||||
return span_exporter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracer_start_as_current_span(
|
||||
test_runner: TestInMemoryRunner,
|
||||
mock_start_as_current_span: mock.Mock,
|
||||
span_exporter: InMemorySpanExporter,
|
||||
):
|
||||
"""Test creation of multiple spans in an E2E runner invocation.
|
||||
|
||||
@@ -112,18 +116,49 @@ async def test_tracer_start_as_current_span(
|
||||
pass
|
||||
|
||||
# Assert
|
||||
expected_start_as_current_span_calls = [
|
||||
mock.call('invocation'),
|
||||
mock.call('execute_tool some_tool'),
|
||||
mock.call('invoke_agent some_root_agent'),
|
||||
mock.call('call_llm'),
|
||||
mock.call('call_llm'),
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert list(sorted(span.name for span in spans)) == [
|
||||
'call_llm',
|
||||
'call_llm',
|
||||
'execute_tool some_tool',
|
||||
'invocation',
|
||||
'invoke_agent some_root_agent',
|
||||
]
|
||||
|
||||
mock_start_as_current_span.assert_has_calls(
|
||||
expected_start_as_current_span_calls,
|
||||
any_order=True,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_preserves_attributes(
|
||||
test_model: BaseLlm, span_exporter: InMemorySpanExporter
|
||||
):
|
||||
"""Test when an exception occurs during tool execution, span attributes are still present on spans where they are expected."""
|
||||
|
||||
# Arrange
|
||||
async def some_tool():
|
||||
raise ValueError('This tool always fails')
|
||||
|
||||
test_agent = Agent(
|
||||
name='some_root_agent',
|
||||
model=test_model,
|
||||
tools=[
|
||||
FunctionTool(some_tool),
|
||||
],
|
||||
)
|
||||
assert mock_start_as_current_span.call_count == len(
|
||||
expected_start_as_current_span_calls
|
||||
|
||||
test_runner = TestInMemoryRunner(test_agent)
|
||||
|
||||
# Act
|
||||
with pytest.raises(ValueError, match='This tool always fails'):
|
||||
async with Aclosing(
|
||||
test_runner.run_async_with_new_session_agen('')
|
||||
) as agen:
|
||||
async for _ in agen:
|
||||
pass
|
||||
|
||||
# Assert
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) > 1
|
||||
assert all(
|
||||
span.attributes is not None and len(span.attributes) > 0
|
||||
for span in spans
|
||||
if span.name != 'invocation' # not expected to have attributes
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user