diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 0e9b9386..f4339f86 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -500,6 +500,30 @@ class MCPSessionManager: ) raise ConnectionError(f'Failed to create MCP session: {e}') from e + def __getstate__(self): + """Custom pickling to exclude non-picklable runtime objects.""" + state = self.__dict__.copy() + # Remove unpicklable entries or those that shouldn't persist across pickle + state['_sessions'] = {} + state['_session_lock_map'] = {} + + # Locks and file-like objects cannot be pickled + state.pop('_lock_map_lock', None) + state.pop('_errlog', None) + + return state + + def __setstate__(self, state): + """Custom unpickling to restore state.""" + self.__dict__.update(state) + # Re-initialize members that were not pickled + self._sessions = {} + self._session_lock_map = {} + self._lock_map_lock = threading.Lock() + # If _errlog was removed during pickling, default to sys.stderr + if not hasattr(self, '_errlog') or self._errlog is None: + self._errlog = sys.stderr + async def close(self): """Closes all sessions and cleans up resources.""" async with self._session_lock: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index ba9a7e80..327df114 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -607,6 +607,39 @@ class TestMCPSessionManager: exit_stack2.aclose.assert_not_called() assert len(manager._sessions) == 0 + @pytest.mark.asyncio + async def test_pickle_mcp_session_manager(self): + """Verify that MCPSessionManager can be pickled and unpickled.""" + import pickle + + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Access the lock to ensure it's initialized + lock = manager._session_lock + assert isinstance(lock, asyncio.Lock) + + # Add a mock session to verify it's cleared on pickling + manager._sessions["test"] = (Mock(), Mock(), asyncio.get_running_loop()) + + # Pickle and unpickle + pickled = pickle.dumps(manager) + unpickled = pickle.loads(pickled) + + # Verify basics are restored + assert unpickled._connection_params == manager._connection_params + + # Verify transient/unpicklable members are re-initialized or cleared + assert unpickled._sessions == {} + assert unpickled._session_lock_map == {} + assert isinstance(unpickled._lock_map_lock, type(manager._lock_map_lock)) + assert unpickled._lock_map_lock is not manager._lock_map_lock + assert unpickled._errlog == sys.stderr + + # Verify we can still get a lock in the new instance + new_lock = unpickled._session_lock + assert isinstance(new_lock, asyncio.Lock) + assert new_lock is not lock + @pytest.mark.asyncio async def test_retry_on_errors_decorator():