From 13a95c463da537dacaa96bee852c51cd62b2fcc5 Mon Sep 17 00:00:00 2001 From: "Xinran (Sherry) Tang" Date: Fri, 19 Sep 2025 11:20:56 -0700 Subject: [PATCH] feat: Add get_events util function in invocation_context PiperOrigin-RevId: 809111315 --- src/google/adk/agents/invocation_context.py | 29 +++++ .../agents/test_invocation_context.py | 119 ++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tests/unittests/agents/test_invocation_context.py diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 18833d99..d5aaada6 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -25,10 +25,12 @@ from pydantic import PrivateAttr from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..events.event import Event from ..memory.base_memory_service import BaseMemoryService from ..plugins.plugin_manager import PluginManager from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session +from ..utils.feature_decorator import working_in_progress from .active_streaming_tool import ActiveStreamingTool from .base_agent import BaseAgent from .live_request_queue import LiveRequestQueue @@ -215,6 +217,33 @@ class InvocationContext(BaseModel): def user_id(self) -> str: return self.session.user_id + @working_in_progress("incomplete feature, don't use yet") + def get_events( + self, + current_invocation: bool = False, + current_branch: bool = False, + ) -> list[Event]: + """Returns the events from the current session. + + Args: + current_invocation: Whether to filter the events by the current + invocation. + current_branch: Whether to filter the events by the current branch. + + Returns: + A list of events from the current session. + """ + results = self.session.events + if current_invocation: + results = [ + event + for event in results + if event.invocation_id == self.invocation_id + ] + if current_branch: + results = [event for event in results if event.branch == self.branch] + return results + def new_invocation_context_id() -> str: return "e-" + str(uuid.uuid4()) diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py new file mode 100644 index 00000000..f85bee1f --- /dev/null +++ b/tests/unittests/agents/test_invocation_context.py @@ -0,0 +1,119 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.session import Session +import pytest + + +class TestInvocationContext: + """Test suite for InvocationContext.""" + + @pytest.fixture + def mock_events(self): + """Create mock events for testing.""" + event1 = Mock(spec=Event) + event1.invocation_id = 'inv_1' + event1.branch = 'agent_1' + + event2 = Mock(spec=Event) + event2.invocation_id = 'inv_1' + event2.branch = 'agent_2' + + event3 = Mock(spec=Event) + event3.invocation_id = 'inv_2' + event3.branch = 'agent_1' + + event4 = Mock(spec=Event) + event4.invocation_id = 'inv_2' + event4.branch = 'agent_2' + + return [event1, event2, event3, event4] + + @pytest.fixture + def mock_invocation_context(self, mock_events): + """Create a mock invocation context for testing.""" + ctx = InvocationContext( + session_service=Mock(spec=BaseSessionService), + agent=Mock(spec=BaseAgent), + invocation_id='inv_1', + branch='agent_1', + session=Mock(spec=Session, events=mock_events), + ) + return ctx + + def test_get_events_returns_all_events_by_default( + self, mock_invocation_context, mock_events + ): + """Tests that get_events returns all events when no filters are applied.""" + events = mock_invocation_context.get_events() + assert events == mock_events + + def test_get_events_filters_by_current_invocation( + self, mock_invocation_context, mock_events + ): + """Tests that get_events correctly filters by the current invocation.""" + event1, event2, _, _ = mock_events + events = mock_invocation_context.get_events(current_invocation=True) + assert events == [event1, event2] + + def test_get_events_filters_by_current_branch( + self, mock_invocation_context, mock_events + ): + """Tests that get_events correctly filters by the current branch.""" + event1, _, event3, _ = mock_events + events = mock_invocation_context.get_events(current_branch=True) + assert events == [event1, event3] + + def test_get_events_filters_by_invocation_and_branch( + self, mock_invocation_context, mock_events + ): + """Tests that get_events filters by invocation and branch.""" + event1, _, _, _ = mock_events + events = mock_invocation_context.get_events( + current_invocation=True, + current_branch=True, + ) + assert events == [event1] + + def test_get_events_with_no_events_in_session(self, mock_invocation_context): + """Tests get_events when the session has no events.""" + mock_invocation_context.session.events = [] + events = mock_invocation_context.get_events() + assert not events + + def test_get_events_with_no_matching_events(self, mock_invocation_context): + """Tests get_events when no events match the filters.""" + mock_invocation_context.invocation_id = 'inv_3' + mock_invocation_context.branch = 'branch_C' + + # Filter by invocation + events = mock_invocation_context.get_events(current_invocation=True) + assert not events + + # Filter by branch + events = mock_invocation_context.get_events(current_branch=True) + assert not events + + # Filter by both + events = mock_invocation_context.get_events( + current_invocation=True, + current_branch=True, + ) + assert not events