feat: start and close ClientSession in a single task in McpSessionManager

Merge https://github.com/google/adk-python/pull/4025

**Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.**

### Link to Issue or Description of Change

**1. Link to an existing issue (if applicable):**

- Closes:
  - #3950
  - #3731
  - #3708

**2. Or, if no issue exists, describe the change:**

**Problem:**
- `ClientSession` of https://github.com/modelcontextprotocol/python-sdk uses AnyIO for async task management.
- AnyIO TaskGroup requires its start and close must happen in a same task.
- Since `McpSessionManager` does not create task per client, the client might be closed by different task, cause the error: `Attempted to exit cancel scope in a different task than it was entered in`.

**Solution:**

I Suggest 2 changes:

Handling the `ClientSession` in a single task
- To start and close `ClientSession` by the same task, we need to wrap the whole lifecycle of `ClientSession` to a single task.
- `SessionContext` wraps the initialization and disposal of `ClientSession` to a single task, ensures that the `ClientSession` will be handled only in a dedicated task.

Add timeout for `ClientSession`
- Since now we are using task per `ClientSession`, task should never be leaked.
- But `McpSessionManager` does not deliver timeout directly to `ClientSession` when the type is not STDIO.
  - There is only timeout for `httpx` client when MCP type is SSE or StreamableHTTP.
  - But the timeout applys only to `httpx` client, so if there is an issue in MCP client itself(e.g. https://github.com/modelcontextprotocol/python-sdk/issues/262), a tool call waits the result **FOREVER**!
- To overcome this issue, I propagated the `sse_read_timeout` to `ClientSession`.
  - `timeout` is too short for timeout for tool call, since its default value is only 5s.
  - `sse_read_timeout` is originally made for read timeout of SSE(default value of 5m or 300s), but actually most of SSE implementations from server (e.g. FastAPI, etc.) sends ping periodically(about 15s I assume), so in a normal circumstances this timeout is quite useless.
  - If the server does not send ping, the timeout is equal to tool call timeout. Therefore, it would be appropriate to use `sse_read_timeout` as tool call timeout.
  - Most of tool calls should finish within 5 minutes, and sse timeout is adjustable if not.
- If this change is not acceptable, we could make a dedicate parameter for tool call timeout(e.g. `tool_call_timeout`).

### Testing Plan
- Although this does not change the interface itself, it changes its own session management logics, some existing tests are no longer valid.
  - I made changes to those tests, especially those of which validate session states(e.g. checking whether `initialize()` called).
  - Since now session is encapsulated with `SessionContext`, we cannot validate the initialized state of the session in `TestMcpSessionManager`, should validate it at `TestSessionContext`.
- Added a simple test for reproducing the issue(`test_create_and_close_session_in_different_tasks`).
- Also made a test for the new component: `SessionContext`.

**Unit Tests:**

- [x] I have added or updated unit tests for my change.
- [x] All unit tests pass locally.

```plaintext
=================================================================================== 3689 passed, 1 skipped, 2205 warnings in 63.39s (0:01:03) ===================================================================================
```

**Manual End-to-End (E2E) Tests:**

_Please provide instructions on how to manually test your changes, including any
necessary setup or configuration. Please provide logs or screenshots to help
reviewers better understand the fix._

### Checklist

- [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document.
- [x] I have performed a self-review of my own code.
- [x] I have commented my code, particularly in hard-to-understand areas.
- [x] I have added tests that prove my fix is effective or that my feature works.
- [x] New and existing unit tests pass locally with my changes.
- [x] I have manually tested my changes end-to-end.
- [ ] ~~Any dependent changes have been merged and published in downstream modules.~~ `no deps has been changed`

### Additional context
This PR is related to https://github.com/modelcontextprotocol/python-sdk/pull/1817 since it also fixes endless tool call awaiting.

Co-authored-by: Kathy Wu <wukathy@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4025 from challenger71498:feat/task-based-mcp-session-manager f7f7cd0c9c96840361c30499d08c33a189f57d86
PiperOrigin-RevId: 856438147
This commit is contained in:
Kathy Wu
2026-01-14 18:09:33 -08:00
committed by Copybara-Service
parent 1133ce219c
commit cce430da79
4 changed files with 847 additions and 45 deletions
@@ -41,6 +41,8 @@ from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel
from pydantic import ConfigDict
from .session_context import SessionContext
logger = logging.getLogger('google_adk.' + __name__)
@@ -385,29 +387,27 @@ class MCPSessionManager:
if hasattr(self._connection_params, 'timeout')
else None
)
sse_read_timeout_in_seconds = (
self._connection_params.sse_read_timeout
if hasattr(self._connection_params, 'sse_read_timeout')
else None
)
try:
client = self._create_client(merged_headers)
is_stdio = isinstance(self._connection_params, StdioConnectionParams)
transports = await asyncio.wait_for(
exit_stack.enter_async_context(client),
session = await asyncio.wait_for(
exit_stack.enter_async_context(
SessionContext(
client=client,
timeout=timeout_in_seconds,
sse_read_timeout=sse_read_timeout_in_seconds,
is_stdio=is_stdio,
)
),
timeout=timeout_in_seconds,
)
# The streamable http client returns a GetSessionCallback in addition to the
# read/write MemoryObjectStreams needed to build the ClientSession, we limit
# then to the two first values to be compatible with all clients.
if isinstance(self._connection_params, StdioConnectionParams):
session = await exit_stack.enter_async_context(
ClientSession(
*transports[:2],
read_timeout_seconds=timedelta(seconds=timeout_in_seconds),
)
)
else:
session = await exit_stack.enter_async_context(
ClientSession(*transports[:2])
)
await asyncio.wait_for(session.initialize(), timeout=timeout_in_seconds)
# Store session and exit stack in the pool
self._sessions[session_key] = (session, exit_stack)
@@ -0,0 +1,194 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack
from datetime import timedelta
import logging
from typing import AsyncContextManager
from typing import Optional
from mcp import ClientSession
logger = logging.getLogger('google_adk.' + __name__)
class SessionContext:
"""Represents the context of a single MCP session within a dedicated task.
AnyIO's TaskGroup/CancelScope requires that the start and end of a scope
occur within the same task. Since MCP clients use AnyIO internally, we need
to ensure that the client's entire lifecycle (creation, usage, and cleanup)
happens within a single dedicated task.
This class spawns a background task that:
1. Enters the MCP client's async context and initializes the session
2. Signals readiness via an asyncio.Event
3. Waits for a close signal
4. Cleans up the client within the same task
This ensures CancelScope constraints are satisfied regardless of which
task calls start() or close().
Can be used in two ways:
1. Direct method calls: start() and close()
2. As an async context manager: async with lifecycle as session: ...
"""
def __init__(
self,
client: AsyncContextManager,
timeout: Optional[float],
sse_read_timeout: Optional[float],
is_stdio: bool = False,
):
"""
Args:
client: An MCP client context manager (e.g., from streamablehttp_client,
sse_client, or stdio_client).
timeout: Timeout in seconds for connection and initialization.
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
server.
is_stdio: Whether this is a stdio connection (affects read timeout).
"""
self._client = client
self._timeout = timeout
self._sse_read_timeout = sse_read_timeout
self._is_stdio = is_stdio
self._session: Optional[ClientSession] = None
self._ready_event = asyncio.Event()
self._close_event = asyncio.Event()
self._task: Optional[asyncio.Task] = None
self._task_lock = asyncio.Lock()
@property
def session(self) -> Optional[ClientSession]:
"""Get the managed ClientSession, if available."""
return self._session
async def start(self) -> ClientSession:
"""Start the runner and wait for the session to be ready.
Returns:
The initialized ClientSession.
Raises:
ConnectionError: If session creation fails.
"""
async with self._task_lock:
if self._session:
logger.debug(
'Session has already been created, returning existing session'
)
return self._session
if self._close_event.is_set():
raise ConnectionError(
'Failed to create MCP session: session already closed'
)
if not self._task:
self._task = asyncio.create_task(self._run())
await self._ready_event.wait()
if self._task.cancelled():
raise ConnectionError('Failed to create MCP session: task cancelled')
if self._task.done() and self._task.exception():
raise ConnectionError(
f'Failed to create MCP session: {self._task.exception()}'
) from self._task.exception()
return self._session
async def close(self):
"""Signal the context task to close and wait for cleanup."""
# Set the close event to signal the task to close.
# Even if start has not been called, we need to set the close event
# to signal the task to close right away.
async with self._task_lock:
self._close_event.set()
# If start has not been called, only set the close event and return
if not self._task:
return
if not self._ready_event.is_set():
self._task.cancel()
try:
await asyncio.wait_for(self._task, timeout=self._timeout)
except asyncio.TimeoutError:
logger.warning('Failed to close MCP session: task timed out')
self._task.cancel()
except asyncio.CancelledError:
pass
except Exception as e:
logger.warning(f'Failed to close MCP session: {e}')
async def __aenter__(self) -> ClientSession:
return await self.start()
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def _run(self):
"""Run the complete session context within a single task."""
try:
async with AsyncExitStack() as exit_stack:
transports = await asyncio.wait_for(
exit_stack.enter_async_context(self._client),
timeout=self._timeout,
)
# The streamable http client returns a GetSessionCallback in addition
# to the read/write MemoryObjectStreams needed to build the
# ClientSession. We limit to the first two values to be compatible
# with all clients.
if self._is_stdio:
session = await exit_stack.enter_async_context(
ClientSession(
*transports[:2],
read_timeout_seconds=timedelta(seconds=self._timeout)
if self._timeout is not None
else None,
)
)
else:
# For SSE and Streamable HTTP clients, use the sse_read_timeout
# instead of the connection timeout as the read_timeout for the session.
session = await exit_stack.enter_async_context(
ClientSession(
*transports[:2],
read_timeout_seconds=timedelta(seconds=self._sse_read_timeout)
if self._sse_read_timeout is not None
else None,
)
)
await asyncio.wait_for(session.initialize(), timeout=self._timeout)
logger.debug('Session has been successfully initialized')
self._session = session
self._ready_event.set()
# Wait for close signal - the session remains valid while we wait
await self._close_event.wait()
except BaseException as e:
logger.warning(f'Error on session runner task: {e}')
raise
finally:
self._ready_event.set()
self._close_event.set()
@@ -56,6 +56,33 @@ class MockAsyncExitStack:
pass
class MockSessionContext:
"""Mock SessionContext for testing."""
def __init__(self, session=None):
"""Initialize MockSessionContext.
Args:
session: The mock session to return from __aenter__ and session property.
"""
self._session = session
self._aenter_mock = AsyncMock(return_value=session)
self._aexit_mock = AsyncMock(return_value=False)
@property
def session(self):
"""Get the mock session."""
return self._session
async def __aenter__(self):
"""Enter the async context manager."""
return await self._aenter_mock()
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit the async context manager."""
return await self._aexit_mock(exc_type, exc_val, exc_tb)
class TestMCPSessionManager:
"""Test suite for MCPSessionManager class."""
@@ -241,7 +268,6 @@ class TestMCPSessionManager:
"""Test creating a new stdio session."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
mock_session = MockClientSession()
mock_exit_stack = MockAsyncExitStack()
with patch(
@@ -251,17 +277,19 @@ class TestMCPSessionManager:
"google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack"
) as mock_exit_stack_class:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.ClientSession"
) as mock_session_class:
"google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
) as mock_session_context_class:
# Setup mocks
mock_exit_stack_class.return_value = mock_exit_stack
mock_stdio.return_value = AsyncMock()
mock_exit_stack.enter_async_context.side_effect = [
("read", "write"), # First call returns transports
mock_session, # Second call returns session
]
mock_session_class.return_value = mock_session
# Mock SessionContext using MockSessionContext
# Create a mock session that will be returned by SessionContext
mock_session = AsyncMock()
mock_session_context = MockSessionContext(session=mock_session)
mock_session_context_class.return_value = mock_session_context
mock_exit_stack.enter_async_context.return_value = mock_session
# Create session
session = await manager.create_session()
@@ -271,8 +299,10 @@ class TestMCPSessionManager:
assert len(manager._sessions) == 1
assert "stdio_session" in manager._sessions
# Verify session was initialized
mock_session.initialize.assert_called_once()
# Verify SessionContext was created
mock_session_context_class.assert_called_once()
# Verify enter_async_context was called (which internally calls __aenter__)
mock_exit_stack.enter_async_context.assert_called_once()
@pytest.mark.asyncio
async def test_create_session_reuse_existing(self):
@@ -300,39 +330,37 @@ class TestMCPSessionManager:
@pytest.mark.asyncio
@patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client")
@patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack")
@patch("google.adk.tools.mcp_tool.mcp_session_manager.ClientSession")
@patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext")
async def test_create_session_timeout(
self, mock_session_class, mock_exit_stack_class, mock_stdio
self, mock_session_context_class, mock_exit_stack_class, mock_stdio
):
"""Test session creation timeout."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
mock_session = MockClientSession()
mock_exit_stack = MockAsyncExitStack()
mock_exit_stack_class.return_value = mock_exit_stack
mock_stdio.return_value = AsyncMock()
mock_exit_stack.enter_async_context.side_effect = [
("read", "write"), # First call returns transports
mock_session, # Second call returns session
]
mock_session_class.return_value = mock_session
# Simulate timeout during session initialization
mock_session.initialize.side_effect = asyncio.TimeoutError("Test timeout")
# Mock SessionContext
mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(
return_value=MockClientSession()
)
mock_session_context.__aexit__ = AsyncMock(return_value=False)
mock_session_context_class.return_value = mock_session_context
# Mock enter_async_context to raise TimeoutError (simulating asyncio.wait_for timeout)
mock_exit_stack.enter_async_context = AsyncMock(
side_effect=asyncio.TimeoutError("Test timeout")
)
# Expect ConnectionError due to timeout
with pytest.raises(ConnectionError, match="Failed to create MCP session"):
await manager.create_session()
# Verify ClientSession called with timeout
mock_session_class.assert_called_with(
"read",
"write",
read_timeout_seconds=timedelta(
seconds=manager._connection_params.timeout
),
)
# Verify SessionContext was created
mock_session_context_class.assert_called_once()
# Verify session was not added to pool
assert not manager._sessions
# Verify cleanup was called
@@ -390,6 +418,36 @@ class TestMCPSessionManager:
assert "Warning: Error during MCP session cleanup" in error_output
assert "Close error 1" in error_output
@pytest.mark.asyncio
@patch("google.adk.tools.mcp_tool.mcp_session_manager.stdio_client")
@patch("google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack")
@patch("google.adk.tools.mcp_tool.mcp_session_manager.SessionContext")
async def test_create_and_close_session_in_different_tasks(
self, mock_session_context_class, mock_exit_stack_class, mock_stdio
):
"""Test creating and closing a session in different tasks."""
manager = MCPSessionManager(self.mock_stdio_connection_params)
mock_exit_stack_class.return_value = MockAsyncExitStack()
mock_stdio.return_value = AsyncMock()
# Mock SessionContext
mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(
return_value=MockClientSession()
)
mock_session_context.__aexit__ = AsyncMock(return_value=False)
mock_session_context_class.return_value = mock_session_context
# Create session in a new task
await asyncio.create_task(manager.create_session())
# Close session in another task
await asyncio.create_task(manager.close())
# Verify session was closed
assert not manager._sessions
@pytest.mark.asyncio
async def test_retry_on_errors_decorator():
File diff suppressed because it is too large Load Diff