chore: Update the retry logic of create session polling

This should slightly increase the timeout also reduce the polling frequency.

PiperOrigin-RevId: 778323416
This commit is contained in:
Shangjie Chen
2025-07-01 21:21:17 -07:00
committed by Copybara-Service
parent 9af2394e0a
commit 3d2f13cecd
@@ -25,6 +25,11 @@ import urllib.parse
from dateutil import parser
from google.genai.errors import ClientError
from tenacity import retry
from tenacity import retry_if_result
from tenacity import RetryError
from tenacity import stop_after_attempt
from tenacity import wait_exponential
from typing_extensions import override
from google import genai
@@ -64,6 +69,20 @@ class VertexAiSessionService(BaseSessionService):
self._location = location
self._agent_engine_id = agent_engine_id
async def _get_session_api_response(
self,
reasoning_engine_id: str,
session_id: str,
api_client: genai.ApiClient,
):
get_session_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
)
get_session_api_response = _convert_api_response(get_session_api_response)
return get_session_api_response
@override
async def create_session(
self,
@@ -95,66 +114,68 @@ class VertexAiSessionService(BaseSessionService):
session_id = api_response['name'].split('/')[-3]
operation_id = api_response['name'].split('/')[-1]
max_retry_attempt = 5
if _is_vertex_express_mode(self._project, self._location):
# Express mode doesn't support LRO, so we need to poll
# the session resource.
# TODO: remove this once LRO polling is supported in Express mode.
for i in range(max_retry_attempt):
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=3),
retry=retry_if_result(lambda response: not response),
reraise=True,
)
async def _poll_session_resource():
try:
await api_client.async_request(
http_method='GET',
path=(
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
),
request_dict={},
return await self._get_session_api_response(
reasoning_engine_id, session_id, api_client
)
break
except ClientError as e:
logger.info('Polling for session %s: %s', session_id, e)
# Add slight exponential backoff to avoid excessive polling.
await asyncio.sleep(1 + 0.5 * i)
else:
raise TimeoutError('Session creation failed.')
except ClientError:
logger.info(f'Polling session resource')
return None
try:
await _poll_session_resource()
except Exception as exc:
raise ValueError('Failed to create session.') from exc
else:
lro_response = None
for _ in range(max_retry_attempt):
@retry(
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=3),
retry=retry_if_result(
lambda response: not response.get('done', False),
),
reraise=True,
)
async def _poll_lro():
lro_response = await api_client.async_request(
http_method='GET',
path=f'operations/{operation_id}',
request_dict={},
)
lro_response = _convert_api_response(lro_response)
return lro_response
if lro_response.get('done', None):
break
await asyncio.sleep(1)
if lro_response is None or not lro_response.get('done', None):
try:
await _poll_lro()
except RetryError as exc:
raise TimeoutError(
f'Timeout waiting for operation {operation_id} to complete.'
)
) from exc
except Exception as exc:
raise ValueError('Failed to create session.') from exc
# Get session resource
get_session_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
get_session_api_response = await self._get_session_api_response(
reasoning_engine_id, session_id, api_client
)
get_session_api_response = _convert_api_response(get_session_api_response)
update_timestamp = isoparse(
get_session_api_response['updateTime']
).timestamp()
session = Session(
app_name=str(app_name),
user_id=str(user_id),
id=str(session_id),
state=get_session_api_response.get('sessionState', {}),
last_update_time=update_timestamp,
last_update_time=isoparse(
get_session_api_response['updateTime']
).timestamp(),
)
return session
@@ -171,12 +192,9 @@ class VertexAiSessionService(BaseSessionService):
api_client = self._get_api_client()
# Get session resource
get_session_api_response = await api_client.async_request(
http_method='GET',
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
request_dict={},
get_session_api_response = await self._get_session_api_response(
reasoning_engine_id, session_id, api_client
)
get_session_api_response = _convert_api_response(get_session_api_response)
if get_session_api_response['userId'] != user_id:
raise ValueError(f'Session not found: {session_id}')