fix: pass tool context into require_confirmation function in McpTool

Aligns with the FunctionTool implementation of require_confirmation. This fixes https://github.com/google/adk-python/issues/4327.

Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 865566362
This commit is contained in:
Kathy Wu
2026-02-04 13:55:08 -08:00
committed by Copybara-Service
parent ac1401bd44
commit 7deffb16fd
2 changed files with 69 additions and 1 deletions
+24 -1
View File
@@ -151,8 +151,31 @@ class McpTool(BaseAuthenticatedTool):
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
if isinstance(self._require_confirmation, Callable):
args_to_call = args.copy()
try:
signature = inspect.signature(self._require_confirmation)
valid_params = set(signature.parameters.keys())
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)
if "tool_context" in valid_params or has_kwargs:
args_to_call["tool_context"] = tool_context
# Filter args_to_call only if there's no **kwargs
if not has_kwargs:
# Add tool_context to valid_params if it was added to args_to_call
if "tool_context" in args_to_call:
valid_params.add("tool_context")
args_to_call = {
k: v for k, v in args_to_call.items() if k in valid_params
}
except ValueError:
args_to_call = args
require_confirmation = await self._invoke_callable(
self._require_confirmation, args
self._require_confirmation, args_to_call
)
else:
require_confirmation = bool(self._require_confirmation)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from unittest.mock import AsyncMock
from unittest.mock import Mock
from unittest.mock import patch
@@ -669,6 +670,50 @@ class TestMCPTool:
args=args, tool_context=tool_context
)
@pytest.mark.asyncio
async def test_run_async_require_confirmation_callable_with_arg_filtering(
self,
):
"""Test require_confirmation=callable with argument filtering."""
async def _require_confirmation_func(
param1: str, tool_context: ToolContext
):
return True
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
require_confirmation=_require_confirmation_func,
)
tool_context = Mock(spec=ToolContext)
tool_context.tool_confirmation = None
tool_context.request_confirmation = Mock()
args = {"param1": "test_value", "extra_arg": 123}
with patch.object(
tool, "_invoke_callable", new_callable=AsyncMock
) as mock_invoke_callable:
mock_invoke_callable.return_value = (
True # Mock the return of require_confirmation
)
result = await tool.run_async(args=args, tool_context=tool_context)
expected_args_to_call = {
"param1": "test_value",
"tool_context": tool_context,
}
mock_invoke_callable.assert_called_once_with(
_require_confirmation_func, expected_args_to_call
)
assert result == {
"error": (
"This tool call requires confirmation, please approve or reject."
)
}
tool_context.request_confirmation.assert_called_once()
@pytest.mark.asyncio
async def test_run_async_require_confirmation_callable_true_no_confirmation(
self,