feat: pass trace context in MCP tool call's _meta field with Otel propagator

PiperOrigin-RevId: 868841079
This commit is contained in:
Leon Ziyang Zhang
2026-02-11 14:17:57 -08:00
committed by Copybara-Service
parent 9dccd6a692
commit bcbfeba953
2 changed files with 65 additions and 4 deletions
@@ -31,6 +31,7 @@ from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
from mcp.shared.session import ProgressFnT
from mcp.types import Tool as McpBaseTool
from opentelemetry import propagate
from typing_extensions import override
from ...agents.callback_context import CallbackContext
@@ -313,6 +314,12 @@ class McpTool(BaseAuthenticatedTool):
headers.update(dynamic_headers)
final_headers = headers if headers else None
# Propagate trace context in the _meta field as sprcified by MCP protocol.
# See https://agentclientprotocol.com/protocol/extensibility#the-meta-field
trace_carrier: Dict[str, str] = {}
propagate.get_global_textmap().inject(carrier=trace_carrier)
meta_trace_context = trace_carrier if trace_carrier else None
# Get the session from the session manager
session = await self._mcp_session_manager.create_session(
headers=final_headers
@@ -325,6 +332,7 @@ class McpTool(BaseAuthenticatedTool):
self._mcp_tool.name,
arguments=args,
progress_callback=resolved_callback,
meta=meta_trace_context,
)
return response.model_dump(exclude_none=True, mode="json")
@@ -25,6 +25,7 @@ from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.mcp_tool import mcp_tool
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.tool_context import ToolContext
@@ -225,7 +226,7 @@ class TestMCPTool:
)
# Fix: call_tool uses 'arguments' parameter, not positional args
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=None
"test_tool", arguments=args, progress_callback=None, meta=None
)
@pytest.mark.asyncio
@@ -262,6 +263,55 @@ class TestMCPTool:
headers = call_args[1]["headers"]
assert headers == {"Authorization": "Bearer test_access_token"}
@patch.object(mcp_tool, "propagate", autospec=True)
@pytest.mark.asyncio
async def test_run_async_impl_with_trace_context(self, mock_propagate):
"""Test running tool with trace context injection."""
mock_propagator = Mock()
def inject_context(carrier, context=None) -> None:
carrier["traceparent"] = (
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
)
carrier["tracestate"] = "foo=bar"
carrier["baggage"] = "baz=qux"
mock_propagator.inject.side_effect = inject_context
mock_propagate.get_global_textmap.return_value = mock_propagator
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
tool_context = Mock(spec=ToolContext)
args = {"param1": "test_value"}
await tool._run_async_impl(
args=args, tool_context=tool_context, credential=None
)
self.mock_session_manager.create_session.assert_called_once_with(
headers=None
)
self.mock_session.call_tool.assert_called_once_with(
"test_tool",
arguments=args,
progress_callback=None,
meta={
"traceparent": (
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
),
"tracestate": "foo=bar",
"baggage": "baz=qux",
},
)
@pytest.mark.asyncio
async def test_get_headers_oauth2(self):
"""Test header generation for OAuth2 credentials."""
@@ -778,7 +828,7 @@ class TestMCPTool:
headers=expected_headers
)
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=None
"test_tool", arguments=args, progress_callback=None, meta=None
)
@pytest.mark.asyncio
@@ -821,7 +871,7 @@ class TestMCPTool:
"X-Tenant-ID": "test-tenant",
}
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=None
"test_tool", arguments=args, progress_callback=None, meta=None
)
def test_init_with_progress_callback(self):
@@ -875,7 +925,10 @@ class TestMCPTool:
)
# Verify progress_callback was passed to call_tool
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args, progress_callback=my_progress_callback
"test_tool",
arguments=args,
progress_callback=my_progress_callback,
meta=None,
)
@pytest.mark.asyncio