From 7deffb16fd652547a8c871826fa1788ec0f99294 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 4 Feb 2026 13:55:08 -0800 Subject: [PATCH] fix: pass tool context into require_confirmation function in McpTool Aligns with the FunctionTool implementation of require_confirmation. This fixes https://github.com/google/adk-python/issues/4327. Co-authored-by: Kathy Wu PiperOrigin-RevId: 865566362 --- src/google/adk/tools/mcp_tool/mcp_tool.py | 25 ++++++++++- .../unittests/tools/mcp_tool/test_mcp_tool.py | 45 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index a5b598fd..719ed662 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -151,8 +151,31 @@ class McpTool(BaseAuthenticatedTool): self, *, args: dict[str, Any], tool_context: ToolContext ) -> Any: if isinstance(self._require_confirmation, Callable): + args_to_call = args.copy() + try: + signature = inspect.signature(self._require_confirmation) + valid_params = set(signature.parameters.keys()) + has_kwargs = any( + param.kind == inspect.Parameter.VAR_KEYWORD + for param in signature.parameters.values() + ) + + if "tool_context" in valid_params or has_kwargs: + args_to_call["tool_context"] = tool_context + + # Filter args_to_call only if there's no **kwargs + if not has_kwargs: + # Add tool_context to valid_params if it was added to args_to_call + if "tool_context" in args_to_call: + valid_params.add("tool_context") + args_to_call = { + k: v for k, v in args_to_call.items() if k in valid_params + } + except ValueError: + args_to_call = args + require_confirmation = await self._invoke_callable( - self._require_confirmation, args + self._require_confirmation, args_to_call ) else: require_confirmation = bool(self._require_confirmation) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 0bf28cb3..09e3529a 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -669,6 +670,50 @@ class TestMCPTool: args=args, tool_context=tool_context ) + @pytest.mark.asyncio + async def test_run_async_require_confirmation_callable_with_arg_filtering( + self, + ): + """Test require_confirmation=callable with argument filtering.""" + + async def _require_confirmation_func( + param1: str, tool_context: ToolContext + ): + return True + + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + require_confirmation=_require_confirmation_func, + ) + tool_context = Mock(spec=ToolContext) + tool_context.tool_confirmation = None + tool_context.request_confirmation = Mock() + args = {"param1": "test_value", "extra_arg": 123} + + with patch.object( + tool, "_invoke_callable", new_callable=AsyncMock + ) as mock_invoke_callable: + mock_invoke_callable.return_value = ( + True # Mock the return of require_confirmation + ) + + result = await tool.run_async(args=args, tool_context=tool_context) + expected_args_to_call = { + "param1": "test_value", + "tool_context": tool_context, + } + mock_invoke_callable.assert_called_once_with( + _require_confirmation_func, expected_args_to_call + ) + + assert result == { + "error": ( + "This tool call requires confirmation, please approve or reject." + ) + } + tool_context.request_confirmation.assert_called_once() + @pytest.mark.asyncio async def test_run_async_require_confirmation_callable_true_no_confirmation( self,