diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 10de7c14..16732884 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -126,6 +126,7 @@ class RemoteA2aAgent(BaseAgent): a2a_request_meta_provider: Optional[ Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, + full_history_when_stateless: bool = False, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -142,6 +143,10 @@ class RemoteA2aAgent(BaseAgent): a2a_request_meta_provider: Optional callable that takes InvocationContext and A2AMessage and returns a metadata object to attach to the A2A request. + full_history_when_stateless: If True, stateless agents (those that do not + return Tasks or context IDs) will receive all session events on every + request. If False, the default behavior of sending only events since the + last reply from the agent will be used. **kwargs: Additional arguments passed to BaseAgent Raises: @@ -168,6 +173,7 @@ class RemoteA2aAgent(BaseAgent): self._a2a_part_converter = a2a_part_converter self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider + self._full_history_when_stateless = full_history_when_stateless # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -365,7 +371,14 @@ class RemoteA2aAgent(BaseAgent): if event.custom_metadata: metadata = event.custom_metadata context_id = metadata.get(A2A_METADATA_PREFIX + "context_id") - break + # Historical note: this behavior originally always applied, regardless + # of whether the agent was stateful or stateless. However, only stateful + # agents can be expected to have previous events in the remote session. + # For backwards compatibility, we maintain this behavior when + # _full_history_when_stateless is false (the default) or if the agent + # is stateful (i.e. returned a context ID). + if not self._full_history_when_stateless or context_id: + break events_to_process.append(event) for event in reversed(events_to_process): diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 8bd4a22f..d395a551 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -665,6 +665,157 @@ class TestRemoteA2aAgentMessageHandling: assert parts == [] assert context_id is None + def test_construct_message_parts_from_session_stops_on_agent_reply(self): + """Test message parts construction stops on agent reply by default.""" + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateless_full_history(self): + """Test full history for stateless agent when enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 3 + assert parts[0].text == "User 1" + assert parts[1].text == "Agent 1" + assert parts[2].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateful_partial_history(self): + """Test partial history for stateful agent when full history is enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = {A2A_METADATA_PREFIX + "context_id": "ctx-1"} + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id == "ctx-1" + @pytest.mark.asyncio async def test_handle_a2a_response_success_with_message(self): """Test successful A2A response handling with message."""