You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: pass tool context into require_confirmation function in McpTool
Aligns with the FunctionTool implementation of require_confirmation. This fixes https://github.com/google/adk-python/issues/4327. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 865566362
This commit is contained in:
committed by
Copybara-Service
parent
ac1401bd44
commit
7deffb16fd
@@ -151,8 +151,31 @@ class McpTool(BaseAuthenticatedTool):
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
if isinstance(self._require_confirmation, Callable):
|
||||
args_to_call = args.copy()
|
||||
try:
|
||||
signature = inspect.signature(self._require_confirmation)
|
||||
valid_params = set(signature.parameters.keys())
|
||||
has_kwargs = any(
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for param in signature.parameters.values()
|
||||
)
|
||||
|
||||
if "tool_context" in valid_params or has_kwargs:
|
||||
args_to_call["tool_context"] = tool_context
|
||||
|
||||
# Filter args_to_call only if there's no **kwargs
|
||||
if not has_kwargs:
|
||||
# Add tool_context to valid_params if it was added to args_to_call
|
||||
if "tool_context" in args_to_call:
|
||||
valid_params.add("tool_context")
|
||||
args_to_call = {
|
||||
k: v for k, v in args_to_call.items() if k in valid_params
|
||||
}
|
||||
except ValueError:
|
||||
args_to_call = args
|
||||
|
||||
require_confirmation = await self._invoke_callable(
|
||||
self._require_confirmation, args
|
||||
self._require_confirmation, args_to_call
|
||||
)
|
||||
else:
|
||||
require_confirmation = bool(self._require_confirmation)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -669,6 +670,50 @@ class TestMCPTool:
|
||||
args=args, tool_context=tool_context
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_require_confirmation_callable_with_arg_filtering(
|
||||
self,
|
||||
):
|
||||
"""Test require_confirmation=callable with argument filtering."""
|
||||
|
||||
async def _require_confirmation_func(
|
||||
param1: str, tool_context: ToolContext
|
||||
):
|
||||
return True
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
require_confirmation=_require_confirmation_func,
|
||||
)
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
tool_context.tool_confirmation = None
|
||||
tool_context.request_confirmation = Mock()
|
||||
args = {"param1": "test_value", "extra_arg": 123}
|
||||
|
||||
with patch.object(
|
||||
tool, "_invoke_callable", new_callable=AsyncMock
|
||||
) as mock_invoke_callable:
|
||||
mock_invoke_callable.return_value = (
|
||||
True # Mock the return of require_confirmation
|
||||
)
|
||||
|
||||
result = await tool.run_async(args=args, tool_context=tool_context)
|
||||
expected_args_to_call = {
|
||||
"param1": "test_value",
|
||||
"tool_context": tool_context,
|
||||
}
|
||||
mock_invoke_callable.assert_called_once_with(
|
||||
_require_confirmation_func, expected_args_to_call
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"error": (
|
||||
"This tool call requires confirmation, please approve or reject."
|
||||
)
|
||||
}
|
||||
tool_context.request_confirmation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_require_confirmation_callable_true_no_confirmation(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user