You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: start and close ClientSession in a single task in McpSessionManager
Merge https://github.com/google/adk-python/pull/4025 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: - #3950 - #3731 - #3708 **2. Or, if no issue exists, describe the change:** **Problem:** - `ClientSession` of https://github.com/modelcontextprotocol/python-sdk uses AnyIO for async task management. - AnyIO TaskGroup requires its start and close must happen in a same task. - Since `McpSessionManager` does not create task per client, the client might be closed by different task, cause the error: `Attempted to exit cancel scope in a different task than it was entered in`. **Solution:** I Suggest 2 changes: Handling the `ClientSession` in a single task - To start and close `ClientSession` by the same task, we need to wrap the whole lifecycle of `ClientSession` to a single task. - `SessionContext` wraps the initialization and disposal of `ClientSession` to a single task, ensures that the `ClientSession` will be handled only in a dedicated task. Add timeout for `ClientSession` - Since now we are using task per `ClientSession`, task should never be leaked. - But `McpSessionManager` does not deliver timeout directly to `ClientSession` when the type is not STDIO. - There is only timeout for `httpx` client when MCP type is SSE or StreamableHTTP. - But the timeout applys only to `httpx` client, so if there is an issue in MCP client itself(e.g. https://github.com/modelcontextprotocol/python-sdk/issues/262), a tool call waits the result **FOREVER**! - To overcome this issue, I propagated the `sse_read_timeout` to `ClientSession`. - `timeout` is too short for timeout for tool call, since its default value is only 5s. - `sse_read_timeout` is originally made for read timeout of SSE(default value of 5m or 300s), but actually most of SSE implementations from server (e.g. FastAPI, etc.) sends ping periodically(about 15s I assume), so in a normal circumstances this timeout is quite useless. - If the server does not send ping, the timeout is equal to tool call timeout. Therefore, it would be appropriate to use `sse_read_timeout` as tool call timeout. - Most of tool calls should finish within 5 minutes, and sse timeout is adjustable if not. - If this change is not acceptable, we could make a dedicate parameter for tool call timeout(e.g. `tool_call_timeout`). ### Testing Plan - Although this does not change the interface itself, it changes its own session management logics, some existing tests are no longer valid. - I made changes to those tests, especially those of which validate session states(e.g. checking whether `initialize()` called). - Since now session is encapsulated with `SessionContext`, we cannot validate the initialized state of the session in `TestMcpSessionManager`, should validate it at `TestSessionContext`. - Added a simple test for reproducing the issue(`test_create_and_close_session_in_different_tasks`). - Also made a test for the new component: `SessionContext`. **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. ```plaintext =================================================================================== 3689 passed, 1 skipped, 2205 warnings in 63.39s (0:01:03) =================================================================================== ``` **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [ ] ~~Any dependent changes have been merged and published in downstream modules.~~ `no deps has been changed` ### Additional context This PR is related to https://github.com/modelcontextprotocol/python-sdk/pull/1817 since it also fixes endless tool call awaiting. Co-authored-by: Kathy Wu <wukathy@google.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4025 from challenger71498:feat/task-based-mcp-session-manager f7f7cd0c9c96840361c30499d08c33a189f57d86 PiperOrigin-RevId: 856438147
This commit is contained in:
committed by
Copybara-Service
parent
1133ce219c
commit
cce430da79
@@ -41,6 +41,8 @@ from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .session_context import SessionContext
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
@@ -385,29 +387,27 @@ class MCPSessionManager:
|
||||
if hasattr(self._connection_params, 'timeout')
|
||||
else None
|
||||
)
|
||||
sse_read_timeout_in_seconds = (
|
||||
self._connection_params.sse_read_timeout
|
||||
if hasattr(self._connection_params, 'sse_read_timeout')
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
client = self._create_client(merged_headers)
|
||||
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
|
||||
|
||||
transports = await asyncio.wait_for(
|
||||
exit_stack.enter_async_context(client),
|
||||
session = await asyncio.wait_for(
|
||||
exit_stack.enter_async_context(
|
||||
SessionContext(
|
||||
client=client,
|
||||
timeout=timeout_in_seconds,
|
||||
sse_read_timeout=sse_read_timeout_in_seconds,
|
||||
is_stdio=is_stdio,
|
||||
)
|
||||
),
|
||||
timeout=timeout_in_seconds,
|
||||
)
|
||||
# The streamable http client returns a GetSessionCallback in addition to the
|
||||
# read/write MemoryObjectStreams needed to build the ClientSession, we limit
|
||||
# then to the two first values to be compatible with all clients.
|
||||
if isinstance(self._connection_params, StdioConnectionParams):
|
||||
session = await exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
*transports[:2],
|
||||
read_timeout_seconds=timedelta(seconds=timeout_in_seconds),
|
||||
)
|
||||
)
|
||||
else:
|
||||
session = await exit_stack.enter_async_context(
|
||||
ClientSession(*transports[:2])
|
||||
)
|
||||
await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds)
|
||||
|
||||
# Store session and exit stack in the pool
|
||||
self._sessions[session_key] = (session, exit_stack)
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
from typing import AsyncContextManager
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class SessionContext:
|
||||
"""Represents the context of a single MCP session within a dedicated task.
|
||||
|
||||
AnyIO's TaskGroup/CancelScope requires that the start and end of a scope
|
||||
occur within the same task. Since MCP clients use AnyIO internally, we need
|
||||
to ensure that the client's entire lifecycle (creation, usage, and cleanup)
|
||||
happens within a single dedicated task.
|
||||
|
||||
This class spawns a background task that:
|
||||
1. Enters the MCP client's async context and initializes the session
|
||||
2. Signals readiness via an asyncio.Event
|
||||
3. Waits for a close signal
|
||||
4. Cleans up the client within the same task
|
||||
|
||||
This ensures CancelScope constraints are satisfied regardless of which
|
||||
task calls start() or close().
|
||||
|
||||
Can be used in two ways:
|
||||
1. Direct method calls: start() and close()
|
||||
2. As an async context manager: async with lifecycle as session: ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: AsyncContextManager,
|
||||
timeout: Optional[float],
|
||||
sse_read_timeout: Optional[float],
|
||||
is_stdio: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
client: An MCP client context manager (e.g., from streamablehttp_client,
|
||||
sse_client, or stdio_client).
|
||||
timeout: Timeout in seconds for connection and initialization.
|
||||
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
|
||||
server.
|
||||
is_stdio: Whether this is a stdio connection (affects read timeout).
|
||||
"""
|
||||
self._client = client
|
||||
self._timeout = timeout
|
||||
self._sse_read_timeout = sse_read_timeout
|
||||
self._is_stdio = is_stdio
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._ready_event = asyncio.Event()
|
||||
self._close_event = asyncio.Event()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def session(self) -> Optional[ClientSession]:
|
||||
"""Get the managed ClientSession, if available."""
|
||||
return self._session
|
||||
|
||||
async def start(self) -> ClientSession:
|
||||
"""Start the runner and wait for the session to be ready.
|
||||
|
||||
Returns:
|
||||
The initialized ClientSession.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If session creation fails.
|
||||
"""
|
||||
async with self._task_lock:
|
||||
if self._session:
|
||||
logger.debug(
|
||||
'Session has already been created, returning existing session'
|
||||
)
|
||||
return self._session
|
||||
|
||||
if self._close_event.is_set():
|
||||
raise ConnectionError(
|
||||
'Failed to create MCP session: session already closed'
|
||||
)
|
||||
|
||||
if not self._task:
|
||||
self._task = asyncio.create_task(self._run())
|
||||
|
||||
await self._ready_event.wait()
|
||||
|
||||
if self._task.cancelled():
|
||||
raise ConnectionError('Failed to create MCP session: task cancelled')
|
||||
|
||||
if self._task.done() and self._task.exception():
|
||||
raise ConnectionError(
|
||||
f'Failed to create MCP session: {self._task.exception()}'
|
||||
) from self._task.exception()
|
||||
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Signal the context task to close and wait for cleanup."""
|
||||
# Set the close event to signal the task to close.
|
||||
# Even if start has not been called, we need to set the close event
|
||||
# to signal the task to close right away.
|
||||
async with self._task_lock:
|
||||
self._close_event.set()
|
||||
|
||||
# If start has not been called, only set the close event and return
|
||||
if not self._task:
|
||||
return
|
||||
|
||||
if not self._ready_event.is_set():
|
||||
self._task.cancel()
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._task, timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning('Failed to close MCP session: task timed out')
|
||||
self._task.cancel()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to close MCP session: {e}')
|
||||
|
||||
async def __aenter__(self) -> ClientSession:
|
||||
return await self.start()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def _run(self):
|
||||
"""Run the complete session context within a single task."""
|
||||
try:
|
||||
async with AsyncExitStack() as exit_stack:
|
||||
transports = await asyncio.wait_for(
|
||||
exit_stack.enter_async_context(self._client),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
# The streamable http client returns a GetSessionCallback in addition
|
||||
# to the read/write MemoryObjectStreams needed to build the
|
||||
# ClientSession. We limit to the first two values to be compatible
|
||||
# with all clients.
|
||||
if self._is_stdio:
|
||||
session = await exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
*transports[:2],
|
||||
read_timeout_seconds=timedelta(seconds=self._timeout)
|
||||
if self._timeout is not None
|
||||
else None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# For SSE and Streamable HTTP clients, use the sse_read_timeout
|
||||
# instead of the connection timeout as the read_timeout for the session.
|
||||
session = await exit_stack.enter_async_context(
|
||||
ClientSession(
|
||||
*transports[:2],
|
||||
read_timeout_seconds=timedelta(seconds=self._sse_read_timeout)
|
||||
if self._sse_read_timeout is not None
|
||||
else None,
|
||||
)
|
||||
)
|
||||
await asyncio.wait_for(session.initialize(), timeout=self._timeout)
|
||||
logger.debug('Session has been successfully initialized')
|
||||
|
||||
self._session = session
|
||||
self._ready_event.set()
|
||||
|
||||
# Wait for close signal - the session remains valid while we wait
|
||||
await self._close_event.wait()
|
||||
except BaseException as e:
|
||||
logger.warning(f'Error on session runner task: {e}')
|
||||
raise
|
||||
finally:
|
||||
self._ready_event.set()
|
||||
self._close_event.set()
|
||||
@@ -56,6 +56,33 @@ class MockAsyncExitStack:
|
||||
pass
|
||||
|
||||
|
||||
class MockSessionContext:
|
||||
"""Mock SessionContext for testing."""
|
||||
|
||||
def __init__(self, session=None):
|
||||
"""Initialize MockSessionContext.
|
||||
|
||||
Args:
|
||||
session: The mock session to return from __aenter__ and session property.
|
||||
"""
|
||||
self._session = session
|
||||
self._aenter_mock = AsyncMock(return_value=session)
|
||||
self._aexit_mock = AsyncMock(return_value=False)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
"""Get the mock session."""
|
||||
return self._session
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter the async context manager."""
|
||||
return await self._aenter_mock()
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit the async context manager."""
|
||||
return await self._aexit_mock(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
class TestMCPSessionManager:
|
||||
"""Test suite for MCPSessionManager class."""
|
||||
|
||||
@@ -241,7 +268,6 @@ class TestMCPSessionManager:
|
||||
"""Test creating a new stdio session."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
|
||||
mock_session = MockClientSession()
|
||||
mock_exit_stack = MockAsyncExitStack()
|
||||
|
||||
with patch(
|
||||
@@ -251,17 +277,19 @@ class TestMCPSessionManager:
|
||||
"google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack"
|
||||
) as mock_exit_stack_class:
|
||||
with patch(
|
||||
"google.adk.tools.mcp_tool.mcp_session_manager.ClientSession"
|
||||
) as mock_session_class:
|
||||
"google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
|
||||
) as mock_session_context_class:
|
||||
|
||||
# Setup mocks
|
||||
mock_exit_stack_class.return_value = mock_exit_stack
|
||||
mock_stdio.return_value = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
("read", "write"), # First call returns transports
|
||||
mock_session, # Second call returns session
|
||||
]
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Mock SessionContext using MockSessionContext
|
||||
# Create a mock session that will be returned by SessionContext
|
||||
mock_session = AsyncMock()
|
||||
mock_session_context = MockSessionContext(session=mock_session)
|
||||
mock_session_context_class.return_value = mock_session_context
|
||||
mock_exit_stack.enter_async_context.return_value = mock_session
|
||||
|
||||
# Create session
|
||||
session = await manager.create_session()
|
||||
@@ -271,8 +299,10 @@ class TestMCPSessionManager:
|
||||
assert len(manager._sessions) == 1
|
||||
assert "stdio_session" in manager._sessions
|
||||
|
||||
# Verify session was initialized
|
||||
mock_session.initialize.assert_called_once()
|
||||
# Verify SessionContext was created
|
||||
mock_session_context_class.assert_called_once()
|
||||
# Verify enter_async_context was called (which internally calls __aenter__)
|
||||
mock_exit_stack.enter_async_context.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_reuse_existing(self):
|
||||
@@ -300,39 +330,37 @@ class TestMCPSessionManager:
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client")
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack")
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.ClientSession")
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext")
|
||||
async def test_create_session_timeout(
|
||||
self, mock_session_class, mock_exit_stack_class, mock_stdio
|
||||
self, mock_session_context_class, mock_exit_stack_class, mock_stdio
|
||||
):
|
||||
"""Test session creation timeout."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
|
||||
mock_session = MockClientSession()
|
||||
mock_exit_stack = MockAsyncExitStack()
|
||||
|
||||
mock_exit_stack_class.return_value = mock_exit_stack
|
||||
mock_stdio.return_value = AsyncMock()
|
||||
mock_exit_stack.enter_async_context.side_effect = [
|
||||
("read", "write"), # First call returns transports
|
||||
mock_session, # Second call returns session
|
||||
]
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Simulate timeout during session initialization
|
||||
mock_session.initialize.side_effect = asyncio.TimeoutError("Test timeout")
|
||||
# Mock SessionContext
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(
|
||||
return_value=MockClientSession()
|
||||
)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session_context_class.return_value = mock_session_context
|
||||
|
||||
# Mock enter_async_context to raise TimeoutError (simulating asyncio.wait_for timeout)
|
||||
mock_exit_stack.enter_async_context = AsyncMock(
|
||||
side_effect=asyncio.TimeoutError("Test timeout")
|
||||
)
|
||||
|
||||
# Expect ConnectionError due to timeout
|
||||
with pytest.raises(ConnectionError, match="Failed to create MCP session"):
|
||||
await manager.create_session()
|
||||
|
||||
# Verify ClientSession called with timeout
|
||||
mock_session_class.assert_called_with(
|
||||
"read",
|
||||
"write",
|
||||
read_timeout_seconds=timedelta(
|
||||
seconds=manager._connection_params.timeout
|
||||
),
|
||||
)
|
||||
# Verify SessionContext was created
|
||||
mock_session_context_class.assert_called_once()
|
||||
# Verify session was not added to pool
|
||||
assert not manager._sessions
|
||||
# Verify cleanup was called
|
||||
@@ -390,6 +418,36 @@ class TestMCPSessionManager:
|
||||
assert "Warning: Error during MCP session cleanup" in error_output
|
||||
assert "Close error 1" in error_output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client")
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack")
|
||||
@patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext")
|
||||
async def test_create_and_close_session_in_different_tasks(
|
||||
self, mock_session_context_class, mock_exit_stack_class, mock_stdio
|
||||
):
|
||||
"""Test creating and closing a session in different tasks."""
|
||||
manager = MCPSessionManager(self.mock_stdio_connection_params)
|
||||
|
||||
mock_exit_stack_class.return_value = MockAsyncExitStack()
|
||||
mock_stdio.return_value = AsyncMock()
|
||||
|
||||
# Mock SessionContext
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(
|
||||
return_value=MockClientSession()
|
||||
)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_session_context_class.return_value = mock_session_context
|
||||
|
||||
# Create session in a new task
|
||||
await asyncio.create_task(manager.create_session())
|
||||
|
||||
# Close session in another task
|
||||
await asyncio.create_task(manager.close())
|
||||
|
||||
# Verify session was closed
|
||||
assert not manager._sessions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_on_errors_decorator():
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user