You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add progress_callback support to MCPTool and MCPToolset
Fixes: https://github.com/google/adk-python/issues/3811 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 866025995
This commit is contained in:
committed by
Copybara-Service
parent
9b112e2d13
commit
adbc37fea1
@@ -0,0 +1,15 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,166 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sample agent demonstrating MCP progress callback feature.
|
||||
|
||||
This sample shows how to use the progress_callback parameter in McpToolset
|
||||
to receive progress notifications from MCP servers during long-running tool
|
||||
executions.
|
||||
|
||||
There are two ways to use progress callbacks:
|
||||
|
||||
1. Simple callback (shared by all tools):
|
||||
Pass a ProgressFnT callback that receives (progress, total, message).
|
||||
|
||||
2. Factory function (per-tool callbacks with runtime context):
|
||||
Pass a ProgressCallbackFactory that takes (tool_name, callback_context, **kwargs)
|
||||
and returns a ProgressFnT or None. This allows different tools to have different
|
||||
progress handling logic, and the factory can access and modify session state
|
||||
via the CallbackContext. The **kwargs ensures forward compatibility for future
|
||||
parameters.
|
||||
|
||||
IMPORTANT: Progress callbacks only work when the MCP server actually sends
|
||||
progress notifications. Most simple MCP servers (like the filesystem server)
|
||||
do not send progress updates. This sample uses a mock server that demonstrates
|
||||
progress reporting.
|
||||
|
||||
Usage:
|
||||
adk run contributing/samples/mcp_progress_callback_agent
|
||||
|
||||
Then try:
|
||||
"Run the long running task with 5 steps"
|
||||
"Process these items: apple, banana, cherry"
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from google.adk.agents.callback_context import CallbackContext
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool import StdioConnectionParams
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.shared.session import ProgressFnT
|
||||
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
_mock_server_path = os.path.join(_current_dir, "mock_progress_server.py")
|
||||
|
||||
|
||||
# Option 1: Simple shared callback
|
||||
async def simple_progress_callback(
|
||||
progress: float,
|
||||
total: float | None,
|
||||
message: str | None,
|
||||
) -> None:
|
||||
"""Handle progress notifications from MCP server.
|
||||
|
||||
This callback is shared by all tools in the toolset.
|
||||
"""
|
||||
if total is not None:
|
||||
percentage = (progress / total) * 100
|
||||
bar_length = 20
|
||||
filled = int(bar_length * progress / total)
|
||||
bar = "=" * filled + "-" * (bar_length - filled)
|
||||
print(f"[{bar}] {percentage:.0f}% ({progress}/{total}) {message or ''}")
|
||||
else:
|
||||
print(f"Progress: {progress} {f'- {message}' if message else ''}")
|
||||
|
||||
|
||||
# Option 2: Factory function for per-tool callbacks with runtime context
|
||||
def progress_callback_factory(
|
||||
tool_name: str,
|
||||
*,
|
||||
callback_context: CallbackContext | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ProgressFnT | None:
|
||||
"""Create a progress callback for a specific tool.
|
||||
|
||||
This factory allows different tools to have different progress handling.
|
||||
It receives a CallbackContext for accessing and modifying runtime information
|
||||
like session state. The **kwargs parameter ensures forward compatibility.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the MCP tool.
|
||||
callback_context: The callback context providing access to session,
|
||||
state, artifacts, and other runtime information. Allows modifying
|
||||
state via ctx.state['key'] = value. May be None if not available.
|
||||
**kwargs: Additional keyword arguments for future extensibility.
|
||||
|
||||
Returns:
|
||||
A progress callback function, or None if no callback is needed.
|
||||
"""
|
||||
# Example: Access session info from context (if available)
|
||||
session_id = "unknown"
|
||||
if callback_context and callback_context.session:
|
||||
session_id = callback_context.session.id
|
||||
|
||||
async def callback(
|
||||
progress: float,
|
||||
total: float | None,
|
||||
message: str | None,
|
||||
) -> None:
|
||||
# Include tool name and session info in the progress output
|
||||
prefix = f"[{tool_name}][session:{session_id}]"
|
||||
if total is not None:
|
||||
percentage = (progress / total) * 100
|
||||
bar_length = 20
|
||||
filled = int(bar_length * progress / total)
|
||||
bar = "=" * filled + "-" * (bar_length - filled)
|
||||
print(f"{prefix} [{bar}] {percentage:.0f}% {message or ''}")
|
||||
# Example: Store progress in state (callback_context allows modification)
|
||||
if callback_context:
|
||||
callback_context.state["last_progress"] = progress
|
||||
callback_context.state["last_total"] = total
|
||||
else:
|
||||
print(
|
||||
f"{prefix} Progress: {progress} {f'- {message}' if message else ''}"
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
root_agent = LlmAgent(
|
||||
model="gemini-2.5-flash",
|
||||
name="progress_demo_agent",
|
||||
instruction="""\
|
||||
You are a helpful assistant that can run long-running tasks.
|
||||
|
||||
Available tools:
|
||||
- long_running_task: Simulates a task with multiple steps. You can specify
|
||||
the number of steps and delay between them.
|
||||
- process_items: Processes a list of items one by one with progress updates.
|
||||
|
||||
When the user asks you to run a task, use these tools and the progress
|
||||
will be logged automatically.
|
||||
|
||||
Example requests:
|
||||
- "Run a long task with 5 steps"
|
||||
- "Process these items: apple, banana, cherry, date"
|
||||
""",
|
||||
tools=[
|
||||
McpToolset(
|
||||
connection_params=StdioConnectionParams(
|
||||
server_params=StdioServerParameters(
|
||||
command=sys.executable, # Use current Python interpreter
|
||||
args=[_mock_server_path],
|
||||
),
|
||||
timeout=60,
|
||||
),
|
||||
# Use factory function for per-tool callbacks (Option 2)
|
||||
# Or use simple_progress_callback for shared callback (Option 1)
|
||||
progress_callback=progress_callback_factory,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,161 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Mock MCP server that sends progress notifications.
|
||||
|
||||
This server demonstrates how MCP servers can send progress updates
|
||||
during long-running tool execution.
|
||||
|
||||
Run this server directly:
|
||||
python mock_progress_server.py
|
||||
|
||||
Or use it with the sample agent:
|
||||
See agent_with_mock_server.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent
|
||||
from mcp.types import Tool
|
||||
|
||||
server = Server("mock-progress-server")
|
||||
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
"""List available tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="long_running_task",
|
||||
description=(
|
||||
"A simulated long-running task that reports progress. "
|
||||
"Use this to test progress callback functionality."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"steps": {
|
||||
"type": "integer",
|
||||
"description": "Number of steps to simulate (default: 5)",
|
||||
"default": 5,
|
||||
},
|
||||
"delay": {
|
||||
"type": "number",
|
||||
"description": (
|
||||
"Delay in seconds between steps (default: 0.5)"
|
||||
),
|
||||
"default": 0.5,
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="process_items",
|
||||
description="Process a list of items with progress reporting.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of items to process",
|
||||
},
|
||||
},
|
||||
"required": ["items"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""Handle tool calls with progress reporting."""
|
||||
ctx = server.request_context
|
||||
|
||||
if name == "long_running_task":
|
||||
steps = arguments.get("steps", 5)
|
||||
delay = arguments.get("delay", 0.5)
|
||||
|
||||
# Get progress token from request metadata
|
||||
progress_token = None
|
||||
if ctx.meta and hasattr(ctx.meta, "progressToken"):
|
||||
progress_token = ctx.meta.progressToken
|
||||
|
||||
for i in range(steps):
|
||||
# Simulate work
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Send progress notification if client supports it
|
||||
if progress_token is not None:
|
||||
await ctx.session.send_progress_notification(
|
||||
progress_token=progress_token,
|
||||
progress=i + 1,
|
||||
total=steps,
|
||||
message=f"Completed step {i + 1} of {steps}",
|
||||
)
|
||||
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text=f"Successfully completed {steps} steps!",
|
||||
)
|
||||
]
|
||||
|
||||
elif name == "process_items":
|
||||
items = arguments.get("items", [])
|
||||
total = len(items)
|
||||
|
||||
progress_token = None
|
||||
if ctx.meta and hasattr(ctx.meta, "progressToken"):
|
||||
progress_token = ctx.meta.progressToken
|
||||
|
||||
results = []
|
||||
for i, item in enumerate(items):
|
||||
# Simulate processing
|
||||
await asyncio.sleep(0.3)
|
||||
results.append(f"Processed: {item}")
|
||||
|
||||
# Send progress
|
||||
if progress_token is not None:
|
||||
await ctx.session.send_progress_notification(
|
||||
progress_token=progress_token,
|
||||
progress=i + 1,
|
||||
total=total,
|
||||
message=f"Processing item: {item}",
|
||||
)
|
||||
|
||||
return [
|
||||
TextContent(
|
||||
type="text",
|
||||
text="\n".join(results),
|
||||
)
|
||||
]
|
||||
|
||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run the MCP server."""
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
server.create_initialization_options(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import logging
|
||||
@@ -21,14 +22,18 @@ from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import runtime_checkable
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from mcp.shared.session import ProgressFnT
|
||||
from mcp.types import Tool as McpBaseTool
|
||||
from typing_extensions import override
|
||||
|
||||
from ...agents.callback_context import CallbackContext
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
@@ -37,7 +42,6 @@ from ...features import FeatureName
|
||||
from ...features import is_feature_enabled
|
||||
from .._gemini_schema_util import _to_gemini_schema
|
||||
from ..base_authenticated_tool import BaseAuthenticatedTool
|
||||
# import
|
||||
from ..tool_context import ToolContext
|
||||
from .mcp_session_manager import MCPSessionManager
|
||||
from .mcp_session_manager import retry_on_errors
|
||||
@@ -45,6 +49,68 @@ from .mcp_session_manager import retry_on_errors
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ProgressCallbackFactory(Protocol):
|
||||
"""Factory protocol for creating per-tool progress callbacks.
|
||||
|
||||
This protocol allows users to create different progress callbacks for
|
||||
different tools based on tool name and runtime context. The factory receives
|
||||
the tool name, a CallbackContext for accessing and modifying session state,
|
||||
and additional keyword arguments for forward compatibility.
|
||||
|
||||
Example usage::
|
||||
|
||||
def my_callback_factory(
|
||||
tool_name: str,
|
||||
*,
|
||||
callback_context: CallbackContext | None = None,
|
||||
**kwargs
|
||||
) -> ProgressFnT | None:
|
||||
session_id = callback_context.session.id if callback_context else "N/A"
|
||||
|
||||
async def callback(progress, total, message):
|
||||
print(f"[{tool_name}] Session {session_id}: {progress}/{total}")
|
||||
# Can modify state in the callback
|
||||
if callback_context:
|
||||
callback_context.state['last_progress'] = progress
|
||||
|
||||
return callback
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=...,
|
||||
progress_callback=my_callback_factory,
|
||||
)
|
||||
|
||||
Note:
|
||||
The **kwargs parameter is required for forward compatibility. Future
|
||||
versions may pass additional parameters. Implementations should accept
|
||||
**kwargs even if they don't use them.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
tool_name: str,
|
||||
*,
|
||||
callback_context: Optional[CallbackContext] = None,
|
||||
**kwargs: Any,
|
||||
) -> Optional[ProgressFnT]:
|
||||
"""Create a progress callback for a specific tool.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the MCP tool.
|
||||
callback_context: The callback context providing access to session,
|
||||
state, artifacts, and other runtime information. Allows modifying
|
||||
state via ctx.state['key'] = value. May be None if not available.
|
||||
**kwargs: Additional keyword arguments for future extensibility.
|
||||
Implementations should accept **kwargs for forward compatibility.
|
||||
|
||||
Returns:
|
||||
A progress callback function, or None if no callback is needed
|
||||
for this tool.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class McpTool(BaseAuthenticatedTool):
|
||||
"""Turns an MCP Tool into an ADK Tool.
|
||||
|
||||
@@ -66,6 +132,9 @@ class McpTool(BaseAuthenticatedTool):
|
||||
header_provider: Optional[
|
||||
Callable[[ReadonlyContext], Dict[str, str]]
|
||||
] = None,
|
||||
progress_callback: Optional[
|
||||
Union[ProgressFnT, ProgressCallbackFactory]
|
||||
] = None,
|
||||
):
|
||||
"""Initializes an McpTool.
|
||||
|
||||
@@ -81,6 +150,17 @@ class McpTool(BaseAuthenticatedTool):
|
||||
or a callable that takes the function's arguments and returns a
|
||||
boolean. If the callable returns True, the tool will require
|
||||
confirmation from the user.
|
||||
header_provider: Optional function to provide dynamic headers.
|
||||
progress_callback: Optional callback to receive progress notifications
|
||||
from MCP server during long-running tool execution. Can be either:
|
||||
|
||||
- A ``ProgressFnT`` callback that receives (progress, total, message).
|
||||
This callback will be used for all invocations.
|
||||
|
||||
- A ``ProgressCallbackFactory`` that creates per-invocation callbacks.
|
||||
The factory receives (tool_name, callback_context, **kwargs) and
|
||||
returns a ProgressFnT or None. This allows callbacks to access
|
||||
and modify runtime context like session state.
|
||||
|
||||
Raises:
|
||||
ValueError: If mcp_tool or mcp_session_manager is None.
|
||||
@@ -98,6 +178,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
self._mcp_session_manager = mcp_session_manager
|
||||
self._require_confirmation = require_confirmation
|
||||
self._header_provider = header_provider
|
||||
self._progress_callback = progress_callback
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
@@ -237,9 +318,50 @@ class McpTool(BaseAuthenticatedTool):
|
||||
headers=final_headers
|
||||
)
|
||||
|
||||
response = await session.call_tool(self._mcp_tool.name, arguments=args)
|
||||
# Resolve progress callback (may be a factory that needs runtime context)
|
||||
resolved_callback = self._resolve_progress_callback(tool_context)
|
||||
|
||||
response = await session.call_tool(
|
||||
self._mcp_tool.name,
|
||||
arguments=args,
|
||||
progress_callback=resolved_callback,
|
||||
)
|
||||
return response.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
def _resolve_progress_callback(
|
||||
self, tool_context: ToolContext
|
||||
) -> Optional[ProgressFnT]:
|
||||
"""Resolve the progress callback for the current invocation.
|
||||
|
||||
If progress_callback is a ProgressCallbackFactory, call it to create
|
||||
a callback with runtime context. Otherwise, return the callback directly.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context for the current invocation.
|
||||
|
||||
Returns:
|
||||
The resolved progress callback, or None if not configured.
|
||||
"""
|
||||
if (
|
||||
not hasattr(self, "_progress_callback")
|
||||
or self._progress_callback is None
|
||||
):
|
||||
return None
|
||||
|
||||
# Determine if callback is a factory by checking if it's a coroutine
|
||||
# function. ProgressFnT is an async function, while ProgressCallbackFactory
|
||||
# is a sync function that returns an async function.
|
||||
if asyncio.iscoroutinefunction(self._progress_callback):
|
||||
return self._progress_callback
|
||||
|
||||
# If it's a regular callable (not async), treat it as a factory
|
||||
if callable(self._progress_callback) and not inspect.iscoroutinefunction(
|
||||
self._progress_callback
|
||||
):
|
||||
return self._progress_callback(self.name, callback_context=tool_context)
|
||||
|
||||
return self._progress_callback
|
||||
|
||||
async def _get_headers(
|
||||
self, tool_context: ToolContext, credential: AuthCredential
|
||||
) -> Optional[dict[str, str]]:
|
||||
@@ -253,7 +375,8 @@ class McpTool(BaseAuthenticatedTool):
|
||||
Dictionary of headers to add to the request, or None if no auth.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key authentication is configured for non-header location.
|
||||
ValueError: If API key authentication is configured for non-header
|
||||
location.
|
||||
"""
|
||||
headers: Optional[dict[str, str]] = None
|
||||
if credential:
|
||||
@@ -284,7 +407,8 @@ class McpTool(BaseAuthenticatedTool):
|
||||
# Handle other HTTP schemes with token
|
||||
headers = {
|
||||
"Authorization": (
|
||||
f"{credential.http.scheme} {credential.http.credentials.token}"
|
||||
f"{credential.http.scheme}"
|
||||
f" {credential.http.credentials.token}"
|
||||
)
|
||||
}
|
||||
elif credential.api_key:
|
||||
|
||||
@@ -31,6 +31,7 @@ from typing import Union
|
||||
import warnings
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.shared.session import ProgressFnT
|
||||
from mcp.types import ListResourcesResult
|
||||
from mcp.types import ListToolsResult
|
||||
from pydantic import model_validator
|
||||
@@ -51,6 +52,7 @@ from .mcp_session_manager import SseConnectionParams
|
||||
from .mcp_session_manager import StdioConnectionParams
|
||||
from .mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from .mcp_tool import MCPTool
|
||||
from .mcp_tool import ProgressCallbackFactory
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
@@ -72,7 +74,8 @@ class McpToolset(BaseToolset):
|
||||
command='npx',
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
),
|
||||
tool_filter=['read_file', 'list_directory'] # Optional: filter specific tools
|
||||
tool_filter=['read_file', 'list_directory'] # Optional: filter specific
|
||||
tools
|
||||
)
|
||||
|
||||
# Use in an agent
|
||||
@@ -106,18 +109,21 @@ class McpToolset(BaseToolset):
|
||||
header_provider: Optional[
|
||||
Callable[[ReadonlyContext], Dict[str, str]]
|
||||
] = None,
|
||||
progress_callback: Optional[
|
||||
Union[ProgressFnT, ProgressCallbackFactory]
|
||||
] = None,
|
||||
):
|
||||
"""Initializes the McpToolset.
|
||||
|
||||
Args:
|
||||
connection_params: The connection parameters to the MCP server. Can be:
|
||||
``StdioConnectionParams`` for using local mcp server (e.g. using ``npx`` or
|
||||
``python3``); or ``SseConnectionParams`` for a local/remote SSE server; or
|
||||
``StreamableHTTPConnectionParams`` for local/remote Streamable http
|
||||
server. Note, ``StdioServerParameters`` is also supported for using local
|
||||
mcp server (e.g. using ``npx`` or ``python3`` ), but it does not support
|
||||
timeout, and we recommend to use ``StdioConnectionParams`` instead when
|
||||
timeout is needed.
|
||||
``StdioConnectionParams`` for using local mcp server (e.g. using ``npx``
|
||||
or ``python3``); or ``SseConnectionParams`` for a local/remote SSE
|
||||
server; or ``StreamableHTTPConnectionParams`` for local/remote
|
||||
Streamable http server. Note, ``StdioServerParameters`` is also
|
||||
supported for using local mcp server (e.g. using ``npx`` or ``python3``
|
||||
), but it does not support timeout, and we recommend to use
|
||||
``StdioConnectionParams`` instead when timeout is needed.
|
||||
tool_filter: Optional filter to select specific tools. Can be either: - A
|
||||
list of tool names to include - A ToolPredicate function for custom
|
||||
filtering logic
|
||||
@@ -126,11 +132,22 @@ class McpToolset(BaseToolset):
|
||||
errlog: TextIO stream for error logging.
|
||||
auth_scheme: The auth scheme of the tool for tool calling
|
||||
auth_credential: The auth credential of the tool for tool calling
|
||||
require_confirmation: Whether tools in this toolset require
|
||||
confirmation. Can be a single boolean or a callable to apply to all
|
||||
tools.
|
||||
require_confirmation: Whether tools in this toolset require confirmation.
|
||||
Can be a single boolean or a callable to apply to all tools.
|
||||
header_provider: A callable that takes a ReadonlyContext and returns a
|
||||
dictionary of headers to be used for the MCP session.
|
||||
progress_callback: Optional callback to receive progress notifications
|
||||
from MCP server during long-running tool execution. Can be either:
|
||||
|
||||
- A ``ProgressFnT`` callback that receives (progress, total, message).
|
||||
This callback will be shared by all tools in the toolset.
|
||||
|
||||
- A ``ProgressCallbackFactory`` that creates per-tool callbacks. The
|
||||
factory receives (tool_name, callback_context, **kwargs) and returns
|
||||
a ProgressFnT or None. This allows different tools to have different
|
||||
progress handling logic and access/modify session state via the
|
||||
CallbackContext. The **kwargs parameter allows for future
|
||||
extensibility.
|
||||
"""
|
||||
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
|
||||
|
||||
@@ -140,6 +157,7 @@ class McpToolset(BaseToolset):
|
||||
self._connection_params = connection_params
|
||||
self._errlog = errlog
|
||||
self._header_provider = header_provider
|
||||
self._progress_callback = progress_callback
|
||||
|
||||
# Create the session manager that will handle the MCP connection
|
||||
self._mcp_session_manager = MCPSessionManager(
|
||||
@@ -270,7 +288,7 @@ class McpToolset(BaseToolset):
|
||||
|
||||
Args:
|
||||
readonly_context: Context used to filter tools available to the agent.
|
||||
If None, all tools in the toolset are returned.
|
||||
If None, all tools in the toolset are returned.
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: A list of tools available under the specified context.
|
||||
@@ -292,6 +310,9 @@ class McpToolset(BaseToolset):
|
||||
auth_credential=self._auth_credential,
|
||||
require_confirmation=self._require_confirmation,
|
||||
header_provider=self._header_provider,
|
||||
progress_callback=self._progress_callback
|
||||
if hasattr(self, "_progress_callback")
|
||||
else None,
|
||||
)
|
||||
|
||||
if self._is_tool_selected(mcp_tool, readonly_context):
|
||||
|
||||
@@ -225,7 +225,7 @@ class TestMCPTool:
|
||||
)
|
||||
# Fix: call_tool uses 'arguments' parameter, not positional args
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -778,7 +778,7 @@ class TestMCPTool:
|
||||
headers=expected_headers
|
||||
)
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -821,5 +821,99 @@ class TestMCPTool:
|
||||
"X-Tenant-ID": "test-tenant",
|
||||
}
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args
|
||||
"test_tool", arguments=args, progress_callback=None
|
||||
)
|
||||
|
||||
def test_init_with_progress_callback(self):
|
||||
"""Test initialization with progress_callback."""
|
||||
|
||||
async def my_progress_callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
progress_callback=my_progress_callback,
|
||||
)
|
||||
|
||||
assert tool._progress_callback == my_progress_callback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_impl_with_progress_callback(self):
|
||||
"""Test running tool with progress_callback."""
|
||||
progress_updates = []
|
||||
|
||||
async def my_progress_callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
progress_updates.append((progress, total, message))
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
progress_callback=my_progress_callback,
|
||||
)
|
||||
|
||||
# Mock the session response
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
|
||||
result = await tool._run_async_impl(
|
||||
args=args, tool_context=tool_context, credential=None
|
||||
)
|
||||
|
||||
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
|
||||
self.mock_session_manager.create_session.assert_called_once_with(
|
||||
headers=None
|
||||
)
|
||||
# Verify progress_callback was passed to call_tool
|
||||
self.mock_session.call_tool.assert_called_once_with(
|
||||
"test_tool", arguments=args, progress_callback=my_progress_callback
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_impl_with_progress_callback_factory(self):
|
||||
"""Test running tool with progress_callback factory that receives context."""
|
||||
factory_calls = []
|
||||
|
||||
def my_callback_factory(tool_name: str, *, callback_context=None, **kwargs):
|
||||
factory_calls.append((tool_name, callback_context))
|
||||
|
||||
async def callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
return callback
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
progress_callback=my_callback_factory,
|
||||
)
|
||||
|
||||
# Mock the session response
|
||||
mcp_response = CallToolResult(
|
||||
content=[TextContent(type="text", text="success")]
|
||||
)
|
||||
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
args = {"param1": "test_value"}
|
||||
|
||||
await tool._run_async_impl(
|
||||
args=args, tool_context=tool_context, credential=None
|
||||
)
|
||||
|
||||
# Verify factory was called with tool name and tool_context as callback_context
|
||||
assert len(factory_calls) == 1
|
||||
assert factory_calls[0][0] == "test_tool"
|
||||
# callback_context is the tool_context itself (ToolContext extends CallbackContext)
|
||||
assert factory_calls[0][1] is tool_context
|
||||
|
||||
@@ -360,6 +360,101 @@ class TestMcpToolset:
|
||||
assert tools[0].name == "tool1"
|
||||
assert tools[1].name == "tool2"
|
||||
|
||||
def test_init_with_progress_callback(self):
|
||||
"""Test initialization with progress_callback."""
|
||||
|
||||
async def my_progress_callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
progress_callback=my_progress_callback,
|
||||
)
|
||||
|
||||
assert toolset._progress_callback == my_progress_callback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_passes_progress_callback_to_mcp_tools(self):
|
||||
"""Test that get_tools passes progress_callback to created MCPTool instances."""
|
||||
progress_updates = []
|
||||
|
||||
async def my_progress_callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
progress_updates.append((progress, total, message))
|
||||
|
||||
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
progress_callback=my_progress_callback,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
# Verify each tool has the progress_callback set
|
||||
for tool in tools:
|
||||
assert tool._progress_callback == my_progress_callback
|
||||
|
||||
def test_init_with_progress_callback_factory(self):
|
||||
"""Test initialization with a ProgressCallbackFactory."""
|
||||
|
||||
def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs):
|
||||
async def callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
return callback
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
progress_callback=my_callback_factory,
|
||||
)
|
||||
|
||||
assert toolset._progress_callback == my_callback_factory
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_passes_factory_to_mcp_tools(self):
|
||||
"""Test that get_tools passes factory directly to MCPTool instances.
|
||||
|
||||
The factory is resolved at runtime in McpTool._run_async_impl, not at
|
||||
tool creation time. This allows the factory to receive ReadonlyContext.
|
||||
"""
|
||||
|
||||
def my_callback_factory(tool_name: str, *, readonly_context=None, **kwargs):
|
||||
async def callback(
|
||||
progress: float, total: float | None, message: str | None
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
return callback
|
||||
|
||||
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
progress_callback=my_callback_factory,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
tools = await toolset.get_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
# Factory is passed directly to each tool (resolved at runtime)
|
||||
for tool in tools:
|
||||
assert tool._progress_callback == my_callback_factory
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_resources(self):
|
||||
"""Test listing resources."""
|
||||
|
||||
Reference in New Issue
Block a user