From 9211f4ce8cc6d918df314d6a2ff13da2e0ef35fa Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 14 Nov 2025 15:21:56 -0800 Subject: [PATCH] fix: Use `async for` to loop through event iterator to get all events in vertex_ai_session_service Fix https://github.com/google/adk-python/issues/3559 Co-authored-by: Shangjie Chen PiperOrigin-RevId: 832476367 --- .../adk/sessions/vertex_ai_session_service.py | 2 +- .../test_vertex_ai_session_service.py | 63 ++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 571ae53a..252a69e0 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -178,7 +178,7 @@ class VertexAiSessionService(BaseSessionService): ) session.events += [ _from_api_event(event) - for event in events_iterator + async for event in events_iterator if event.timestamp.timestamp() <= update_timestamp ] diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 29c3c74c..fa80dc9a 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import copy +import datetime import re import types from typing import Any @@ -130,6 +130,36 @@ MOCK_SESSION_JSON_PAGE2 = { 'user_id': 'user_with_pages', } +MOCK_SESSION_JSON_5 = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/5' + ), + 'update_time': '2024-12-12T12:15:12.123456Z', + 'user_id': 'user_with_many_events', +} + + +def _generate_mock_events_for_session_5(num_events): + events = [] + start_time = isoparse('2024-12-12T12:12:12.123456Z') + for i in range(num_events): + event_time = start_time + datetime.timedelta(microseconds=i * 1000) + events.append({ + 'name': ( + 'projects/test-project/locations/test-location/' + f'reasoningEngines/123/sessions/5/events/{i}' + ), + 'invocation_id': f'invocation_{i}', + 'author': 'user_with_many_events', + 'timestamp': event_time.isoformat().replace('+00:00', 'Z'), + }) + return events + + +MANY_EVENTS_COUNT = 200 +MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT) + MOCK_SESSION = Session( app_name='123', user_id='user', @@ -228,6 +258,11 @@ def _convert_to_object(data): return data +async def to_async_iterator(data): + for item in data: + yield item + + class MockAsyncClient: """Mocks the API Client.""" @@ -330,7 +365,7 @@ class MockAsyncClient: for event in events if isoparse(event['timestamp']) >= after_timestamp ] - return [_convert_to_object(event) for event in events] + return to_async_iterator([_convert_to_object(event) for event in events]) async def _append_event( self, @@ -496,6 +531,22 @@ async def test_get_session_with_after_timestamp_filter(): assert session.events[0].id == '456' +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_get_session_with_many_events(mock_api_client_instance): + mock_api_client_instance.session_dict['5'] = MOCK_SESSION_JSON_5 + mock_api_client_instance.event_dict['5'] = ( + copy.deepcopy(MOCK_EVENTS_JSON_5), + None, + ) + session_service = mock_vertex_ai_session_service() + session = await session_service.get_session( + app_name='123', user_id='user_with_many_events', session_id='5' + ) + assert session is not None + assert len(session.events) == MANY_EVENTS_COUNT + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_list_sessions(): @@ -524,7 +575,13 @@ async def test_list_sessions_all_users(): session_service = mock_vertex_ai_session_service() sessions = await session_service.list_sessions(app_name='123', user_id=None) assert len(sessions.sessions) == 5 - assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'} + assert {s.id for s in sessions.sessions} == { + '1', + '2', + '3', + 'page1', + 'page2', + } @pytest.mark.asyncio