You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: pass trace context in MCP tool call's _meta field with Otel propagator
PiperOrigin-RevId: 868841079
This commit is contained in:
committed by
Copybara-Service
parent
9dccd6a692
commit
bcbfeba953
@@ -31,6 +31,7 @@ from fastapi.openapi.models import APIKeyIn
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from mcp.shared.session import ProgressFnT
|
||||
from mcp.types import Tool as McpBaseTool
|
||||
from opentelemetry import propagate
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.callback_context import CallbackContext
|
||||
@@ -313,6 +314,12 @@ class McpTool(BaseAuthenticatedTool):
|
||||
headers.update(dynamic_headers)
|
||||
final_headers = headers if headers else None
|
||||
|
||||
# Propagate trace context in the _meta field as sprcified by MCP protocol.
|
||||
# See https://agentclientprotocol.com/protocol/extensibility#the-meta-field
|
||||
trace_carrier: Dict[str, str] = {}
|
||||
propagate.get_global_textmap().inject(carrier=trace_carrier)
|
||||
meta_trace_context = trace_carrier if trace_carrier else None
|
||||
|
||||
# Get the session from the session manager
|
||||
session = await self._mcp_session_manager.create_session(
|
||||
headers=final_headers
|
||||
@@ -325,6 +332,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
self._mcp_tool.name,
|
||||
arguments=args,
|
||||
progress_callback=resolved_callback,
|
||||
meta=meta_trace_context,
|
||||
)
|
||||
return response.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.tools.mcp_tool import mcp_tool
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
|
||||
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
@@ -225,7 +226,7 @@ class TestMCPTool:
|
||||
)
|
||||
# Fix: call_tool uses 'arguments' parameter, not positional args
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
"test_tool", arguments=args, progress_callback=None, meta=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -262,6 +263,55 @@ class TestMCPTool:
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {"Authorization": "Bearer test_access_token"}
|
||||
|
||||
@patch.object(mcp_tool, "propagate", autospec=True)
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_impl_with_trace_context(self, mock_propagate):
|
||||
"""Test running tool with trace context injection."""
|
||||
mock_propagator = Mock()
|
||||
|
||||
def inject_context(carrier, context=None) -> None:
|
||||
carrier["traceparent"] = (
|
||||
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
|
||||
)
|
||||
carrier["tracestate"] = "foo=bar"
|
||||
carrier["baggage"] = "baz=qux"
|
||||
|
||||
mock_propagator.inject.side_effect = inject_context
|
||||
mock_propagate.get_global_textmap.return_value = mock_propagator
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
|
||||
await tool._run_async_impl(
|
||||
args=args, tool_context=tool_context, credential=None
|
||||
)
|
||||
|
||||
self.mock_session_manager.create_session.assert_called_once_with(
|
||||
headers=None
|
||||
)
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool",
|
||||
arguments=args,
|
||||
progress_callback=None,
|
||||
meta={
|
||||
"traceparent": (
|
||||
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
|
||||
),
|
||||
"tracestate": "foo=bar",
|
||||
"baggage": "baz=qux",
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_oauth2(self):
|
||||
"""Test header generation for OAuth2 credentials."""
|
||||
@@ -778,7 +828,7 @@ class TestMCPTool:
|
||||
headers=expected_headers
|
||||
)
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
"test_tool", arguments=args, progress_callback=None, meta=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -821,7 +871,7 @@ class TestMCPTool:
|
||||
"X-Tenant-ID": "test-tenant",
|
||||
}
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
"test_tool", arguments=args, progress_callback=None, meta=None
|
||||
)
|
||||
|
||||
def test_init_with_progress_callback(self):
|
||||
@@ -875,7 +925,10 @@ class TestMCPTool:
|
||||
)
|
||||
# Verify progress_callback was passed to call_tool
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args, progress_callback=my_progress_callback
|
||||
"test_tool",
|
||||
arguments=args,
|
||||
progress_callback=my_progress_callback,
|
||||
meta=None,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user