fix: Use async for to loop through event iterator to get all events in vertex_ai_session_service

Fix https://github.com/google/adk-python/issues/3559

Co-authored-by: Shangjie Chen <deanchen@google.com>
PiperOrigin-RevId: 832476367
This commit is contained in:
Shangjie Chen
2025-11-14 15:21:56 -08:00
committed by Copybara-Service
parent a754c96d3c
commit 9211f4ce8c
2 changed files with 61 additions and 4 deletions
@@ -178,7 +178,7 @@ class VertexAiSessionService(BaseSessionService):
)
session.events += [
_from_api_event(event)
for event in events_iterator
async for event in events_iterator
if event.timestamp.timestamp() <= update_timestamp
]
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import re
import types
from typing import Any
@@ -130,6 +130,36 @@ MOCK_SESSION_JSON_PAGE2 = {
'user_id': 'user_with_pages',
}
MOCK_SESSION_JSON_5 = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/5'
),
'update_time': '2024-12-12T12:15:12.123456Z',
'user_id': 'user_with_many_events',
}
def _generate_mock_events_for_session_5(num_events):
events = []
start_time = isoparse('2024-12-12T12:12:12.123456Z')
for i in range(num_events):
event_time = start_time + datetime.timedelta(microseconds=i * 1000)
events.append({
'name': (
'projects/test-project/locations/test-location/'
f'reasoningEngines/123/sessions/5/events/{i}'
),
'invocation_id': f'invocation_{i}',
'author': 'user_with_many_events',
'timestamp': event_time.isoformat().replace('+00:00', 'Z'),
})
return events
MANY_EVENTS_COUNT = 200
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
MOCK_SESSION = Session(
app_name='123',
user_id='user',
@@ -228,6 +258,11 @@ def _convert_to_object(data):
return data
async def to_async_iterator(data):
for item in data:
yield item
class MockAsyncClient:
"""Mocks the API Client."""
@@ -330,7 +365,7 @@ class MockAsyncClient:
for event in events
if isoparse(event['timestamp']) >= after_timestamp
]
return [_convert_to_object(event) for event in events]
return to_async_iterator([_convert_to_object(event) for event in events])
async def _append_event(
self,
@@ -496,6 +531,22 @@ async def test_get_session_with_after_timestamp_filter():
assert session.events[0].id == '456'
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_get_session_with_many_events(mock_api_client_instance):
mock_api_client_instance.session_dict['5'] = MOCK_SESSION_JSON_5
mock_api_client_instance.event_dict['5'] = (
copy.deepcopy(MOCK_EVENTS_JSON_5),
None,
)
session_service = mock_vertex_ai_session_service()
session = await session_service.get_session(
app_name='123', user_id='user_with_many_events', session_id='5'
)
assert session is not None
assert len(session.events) == MANY_EVENTS_COUNT
@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_list_sessions():
@@ -524,7 +575,13 @@ async def test_list_sessions_all_users():
session_service = mock_vertex_ai_session_service()
sessions = await session_service.list_sessions(app_name='123', user_id=None)
assert len(sessions.sessions) == 5
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
assert {s.id for s in sessions.sessions} == {
'1',
'2',
'3',
'page1',
'page2',
}
@pytest.mark.asyncio