diff --git a/src/google/adk/agents/callback_context.py b/src/google/adk/agents/callback_context.py index 08353860..37ceb176 100644 --- a/src/google/adk/agents/callback_context.py +++ b/src/google/adk/agents/callback_context.py @@ -176,3 +176,27 @@ class CallbackContext(ReadonlyContext): return await self._invocation_context.credential_service.load_credential( auth_config, self ) + + async def add_session_to_memory(self) -> None: + """Triggers memory generation for the current session. + + This method saves the current session's events to the memory service, + enabling the agent to recall information from past interactions. + + Raises: + ValueError: If memory service is not available. + + Example: + ```python + async def my_after_agent_callback(callback_context: CallbackContext): + # Save conversation to memory at the end of each interaction + await callback_context.add_session_to_memory() + ``` + """ + if self._invocation_context.memory_service is None: + raise ValueError( + "Cannot add session to memory: memory service is not available." + ) + await self._invocation_context.memory_service.add_session_to_memory( + self._invocation_context.session + ) diff --git a/tests/unittests/agents/test_callback_context.py b/tests/unittests/agents/test_callback_context.py index 586e95fa..f3f70249 100644 --- a/tests/unittests/agents/test_callback_context.py +++ b/tests/unittests/agents/test_callback_context.py @@ -321,3 +321,87 @@ class TestCallbackContext: version=None, ) assert result == test_artifact + + +class TestCallbackContextAddSessionToMemory: + """Test the add_session_to_memory method in CallbackContext.""" + + @pytest.mark.asyncio + async def test_add_session_to_memory_success(self, mock_invocation_context): + """Test that add_session_to_memory calls the memory service correctly.""" + memory_service = AsyncMock() + mock_invocation_context.memory_service = memory_service + + context = CallbackContext(mock_invocation_context) + await context.add_session_to_memory() + + memory_service.add_session_to_memory.assert_called_once_with( + mock_invocation_context.session + ) + + @pytest.mark.asyncio + async def test_add_session_to_memory_no_service_raises( + self, mock_invocation_context + ): + """Test that add_session_to_memory raises ValueError when memory service is None.""" + mock_invocation_context.memory_service = None + + context = CallbackContext(mock_invocation_context) + + with pytest.raises( + ValueError, + match=( + r"Cannot add session to memory: memory service is not available\." + ), + ): + await context.add_session_to_memory() + + @pytest.mark.asyncio + async def test_add_session_to_memory_passes_through_service_exceptions( + self, mock_invocation_context + ): + """Test that add_session_to_memory passes through exceptions from the memory service.""" + memory_service = AsyncMock() + memory_service.add_session_to_memory.side_effect = Exception( + "Memory service error" + ) + mock_invocation_context.memory_service = memory_service + + context = CallbackContext(mock_invocation_context) + + with pytest.raises(Exception, match="Memory service error"): + await context.add_session_to_memory() + + +class TestToolContextAddSessionToMemory: + """Test the add_session_to_memory method in ToolContext.""" + + @pytest.mark.asyncio + async def test_add_session_to_memory_success(self, mock_invocation_context): + """Test that ToolContext.add_session_to_memory calls the memory service correctly.""" + memory_service = AsyncMock() + mock_invocation_context.memory_service = memory_service + + tool_context = ToolContext(mock_invocation_context) + await tool_context.add_session_to_memory() + + memory_service.add_session_to_memory.assert_called_once_with( + mock_invocation_context.session + ) + + @pytest.mark.asyncio + async def test_add_session_to_memory_no_service_raises( + self, mock_invocation_context + ): + """Test that ToolContext.add_session_to_memory raises ValueError when memory service is None.""" + mock_invocation_context.memory_service = None + + tool_context = ToolContext(mock_invocation_context) + + with pytest.raises( + ValueError, + match=( + r"Cannot add session to memory: memory service is not available\." + ), + ): + await tool_context.add_session_to_memory()