fix: Update the retry_on_closed_resource decorator to retry on all errors

Retrying only on closed_resource error is not enough to be reliable for production environments due to the other network errors that may occur -- remote protocol error, read timeout, etc. We will update this to retry on all errors. Since it is only a one-time retry, it should not affect latency significantly. Fixes https://github.com/google/adk-python/issues/2561.

PiperOrigin-RevId: 831153803
This commit is contained in:
Google Team Member
2025-11-11 18:43:46 -08:00
committed by Copybara-Service
parent 999af55880
commit 3674fbbe8f
4 changed files with 17 additions and 17 deletions
@@ -108,10 +108,10 @@ class StreamableHTTPConnectionParams(BaseModel):
terminate_on_close: bool = True
def retry_on_errors(func):
"""Decorator to automatically retry action when MCP session errors occur.
def retry_on_closed_resource(func):
"""Decorator to automatically retry action when MCP session is closed.
When MCP session errors occur, the decorator will automatically retry the
When MCP session was closed, the decorator will automatically retry the
action once. The create_session method will handle creating a new session
if the old one was disconnected.
@@ -126,11 +126,11 @@ def retry_on_errors(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except Exception as e:
# If an error is thrown, we will retry the function to reconnect to the
# server. create_session will handle detecting and replacing disconnected
# sessions.
logger.info('Retrying %s due to error: %s', func.__name__, e)
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
# If the session connection is closed or unusable, we will retry the
# function to reconnect to the server. create_session will handle
# detecting and replacing disconnected sessions.
logger.info('Retrying %s due to closed resource', func.__name__)
return await func(self, *args, **kwargs)
return wrapper
+2 -2
View File
@@ -32,7 +32,7 @@ from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from .._gemini_schema_util import _to_gemini_schema
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
from .mcp_session_manager import retry_on_closed_resource
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
# their Python version to 3.10 if it fails.
@@ -184,7 +184,7 @@ class McpTool(BaseAuthenticatedTool):
return {"error": "This tool call is rejected."}
return await super().run_async(args=args, tool_context=tool_context)
@retry_on_errors
@retry_on_closed_resource
@override
async def _run_async_impl(
self, *, args, tool_context: ToolContext, credential: AuthCredential
+2 -2
View File
@@ -37,7 +37,7 @@ from ..base_toolset import ToolPredicate
from ..tool_configs import BaseToolConfig
from ..tool_configs import ToolArgsConfig
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
from .mcp_session_manager import retry_on_closed_resource
from .mcp_session_manager import SseConnectionParams
from .mcp_session_manager import StdioConnectionParams
from .mcp_session_manager import StreamableHTTPConnectionParams
@@ -155,7 +155,7 @@ class McpToolset(BaseToolset):
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
@retry_on_errors
@retry_on_closed_resource
async def get_tools(
self,
readonly_context: Optional[ReadonlyContext] = None,
@@ -32,7 +32,7 @@ pytestmark = pytest.mark.skipif(
# Import dependencies with version checking
try:
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors
from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
@@ -44,7 +44,7 @@ except ImportError as e:
pass
MCPSessionManager = DummyClass
retry_on_errors = lambda x: x
retry_on_closed_resource = lambda x: x
SseConnectionParams = DummyClass
StdioConnectionParams = DummyClass
StreamableHTTPConnectionParams = DummyClass
@@ -375,12 +375,12 @@ class TestMCPSessionManager:
assert "Close error 1" in error_output
def test_retry_on_errors_decorator():
"""Test the retry_on_errors decorator."""
def test_retry_on_closed_resource_decorator():
"""Test the retry_on_closed_resource decorator."""
call_count = 0
@retry_on_errors
@retry_on_closed_resource
async def mock_function(self):
nonlocal call_count
call_count += 1