You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Use async for to loop through event iterator to get all events in vertex_ai_session_service
Fix https://github.com/google/adk-python/issues/3559 Co-authored-by: Shangjie Chen <deanchen@google.com> PiperOrigin-RevId: 832476367
This commit is contained in:
committed by
Copybara-Service
parent
a754c96d3c
commit
9211f4ce8c
@@ -178,7 +178,7 @@ class VertexAiSessionService(BaseSessionService):
|
||||
)
|
||||
session.events += [
|
||||
_from_api_event(event)
|
||||
for event in events_iterator
|
||||
async for event in events_iterator
|
||||
if event.timestamp.timestamp() <= update_timestamp
|
||||
]
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
import types
|
||||
from typing import Any
|
||||
@@ -130,6 +130,36 @@ MOCK_SESSION_JSON_PAGE2 = {
|
||||
'user_id': 'user_with_pages',
|
||||
}
|
||||
|
||||
MOCK_SESSION_JSON_5 = {
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
'reasoningEngines/123/sessions/5'
|
||||
),
|
||||
'update_time': '2024-12-12T12:15:12.123456Z',
|
||||
'user_id': 'user_with_many_events',
|
||||
}
|
||||
|
||||
|
||||
def _generate_mock_events_for_session_5(num_events):
|
||||
events = []
|
||||
start_time = isoparse('2024-12-12T12:12:12.123456Z')
|
||||
for i in range(num_events):
|
||||
event_time = start_time + datetime.timedelta(microseconds=i * 1000)
|
||||
events.append({
|
||||
'name': (
|
||||
'projects/test-project/locations/test-location/'
|
||||
f'reasoningEngines/123/sessions/5/events/{i}'
|
||||
),
|
||||
'invocation_id': f'invocation_{i}',
|
||||
'author': 'user_with_many_events',
|
||||
'timestamp': event_time.isoformat().replace('+00:00', 'Z'),
|
||||
})
|
||||
return events
|
||||
|
||||
|
||||
MANY_EVENTS_COUNT = 200
|
||||
MOCK_EVENTS_JSON_5 = _generate_mock_events_for_session_5(MANY_EVENTS_COUNT)
|
||||
|
||||
MOCK_SESSION = Session(
|
||||
app_name='123',
|
||||
user_id='user',
|
||||
@@ -228,6 +258,11 @@ def _convert_to_object(data):
|
||||
return data
|
||||
|
||||
|
||||
async def to_async_iterator(data):
|
||||
for item in data:
|
||||
yield item
|
||||
|
||||
|
||||
class MockAsyncClient:
|
||||
"""Mocks the API Client."""
|
||||
|
||||
@@ -330,7 +365,7 @@ class MockAsyncClient:
|
||||
for event in events
|
||||
if isoparse(event['timestamp']) >= after_timestamp
|
||||
]
|
||||
return [_convert_to_object(event) for event in events]
|
||||
return to_async_iterator([_convert_to_object(event) for event in events])
|
||||
|
||||
async def _append_event(
|
||||
self,
|
||||
@@ -496,6 +531,22 @@ async def test_get_session_with_after_timestamp_filter():
|
||||
assert session.events[0].id == '456'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_get_session_with_many_events(mock_api_client_instance):
|
||||
mock_api_client_instance.session_dict['5'] = MOCK_SESSION_JSON_5
|
||||
mock_api_client_instance.event_dict['5'] = (
|
||||
copy.deepcopy(MOCK_EVENTS_JSON_5),
|
||||
None,
|
||||
)
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
session = await session_service.get_session(
|
||||
app_name='123', user_id='user_with_many_events', session_id='5'
|
||||
)
|
||||
assert session is not None
|
||||
assert len(session.events) == MANY_EVENTS_COUNT
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures('mock_get_api_client')
|
||||
async def test_list_sessions():
|
||||
@@ -524,7 +575,13 @@ async def test_list_sessions_all_users():
|
||||
session_service = mock_vertex_ai_session_service()
|
||||
sessions = await session_service.list_sessions(app_name='123', user_id=None)
|
||||
assert len(sessions.sessions) == 5
|
||||
assert {s.id for s in sessions.sessions} == {'1', '2', '3', 'page1', 'page2'}
|
||||
assert {s.id for s in sessions.sessions} == {
|
||||
'1',
|
||||
'2',
|
||||
'3',
|
||||
'page1',
|
||||
'page2',
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user