You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Update the retry logic of create session polling
This should slightly increase the timeout also reduce the polling frequency. PiperOrigin-RevId: 778323416
This commit is contained in:
committed by
Copybara-Service
parent
9af2394e0a
commit
3d2f13cecd
@@ -25,6 +25,11 @@ import urllib.parse
|
||||
|
||||
from dateutil import parser
|
||||
from google.genai.errors import ClientError
|
||||
from tenacity import retry
|
||||
from tenacity import retry_if_result
|
||||
from tenacity import RetryError
|
||||
from tenacity import stop_after_attempt
|
||||
from tenacity import wait_exponential
|
||||
from typing_extensions import override
|
||||
|
||||
from google import genai
|
||||
@@ -64,6 +69,20 @@ class VertexAiSessionService(BaseSessionService):
|
||||
self._location = location
|
||||
self._agent_engine_id = agent_engine_id
|
||||
|
||||
async def _get_session_api_response(
|
||||
self,
|
||||
reasoning_engine_id: str,
|
||||
session_id: str,
|
||||
api_client: genai.ApiClient,
|
||||
):
|
||||
get_session_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
)
|
||||
get_session_api_response = _convert_api_response(get_session_api_response)
|
||||
return get_session_api_response
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
self,
|
||||
@@ -95,66 +114,68 @@ class VertexAiSessionService(BaseSessionService):
|
||||
|
||||
session_id = api_response['name'].split('/')[-3]
|
||||
operation_id = api_response['name'].split('/')[-1]
|
||||
|
||||
max_retry_attempt = 5
|
||||
|
||||
if _is_vertex_express_mode(self._project, self._location):
|
||||
# Express mode doesn't support LRO, so we need to poll
|
||||
# the session resource.
|
||||
# TODO: remove this once LRO polling is supported in Express mode.
|
||||
for i in range(max_retry_attempt):
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
retry=retry_if_result(lambda response: not response),
|
||||
reraise=True,
|
||||
)
|
||||
async def _poll_session_resource():
|
||||
try:
|
||||
await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=(
|
||||
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
|
||||
),
|
||||
request_dict={},
|
||||
return await self._get_session_api_response(
|
||||
reasoning_engine_id, session_id, api_client
|
||||
)
|
||||
break
|
||||
except ClientError as e:
|
||||
logger.info('Polling for session %s: %s', session_id, e)
|
||||
# Add slight exponential backoff to avoid excessive polling.
|
||||
await asyncio.sleep(1 + 0.5 * i)
|
||||
else:
|
||||
raise TimeoutError('Session creation failed.')
|
||||
except ClientError:
|
||||
logger.info(f'Polling session resource')
|
||||
return None
|
||||
|
||||
try:
|
||||
await _poll_session_resource()
|
||||
except Exception as exc:
|
||||
raise ValueError('Failed to create session.') from exc
|
||||
else:
|
||||
lro_response = None
|
||||
for _ in range(max_retry_attempt):
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=3),
|
||||
retry=retry_if_result(
|
||||
lambda response: not response.get('done', False),
|
||||
),
|
||||
reraise=True,
|
||||
)
|
||||
async def _poll_lro():
|
||||
lro_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'operations/{operation_id}',
|
||||
request_dict={},
|
||||
)
|
||||
lro_response = _convert_api_response(lro_response)
|
||||
return lro_response
|
||||
|
||||
if lro_response.get('done', None):
|
||||
break
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if lro_response is None or not lro_response.get('done', None):
|
||||
try:
|
||||
await _poll_lro()
|
||||
except RetryError as exc:
|
||||
raise TimeoutError(
|
||||
f'Timeout waiting for operation {operation_id} to complete.'
|
||||
)
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
raise ValueError('Failed to create session.') from exc
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
get_session_api_response = await self._get_session_api_response(
|
||||
reasoning_engine_id, session_id, api_client
|
||||
)
|
||||
get_session_api_response = _convert_api_response(get_session_api_response)
|
||||
|
||||
update_timestamp = isoparse(
|
||||
get_session_api_response['updateTime']
|
||||
).timestamp()
|
||||
session = Session(
|
||||
app_name=str(app_name),
|
||||
user_id=str(user_id),
|
||||
id=str(session_id),
|
||||
state=get_session_api_response.get('sessionState', {}),
|
||||
last_update_time=update_timestamp,
|
||||
last_update_time=isoparse(
|
||||
get_session_api_response['updateTime']
|
||||
).timestamp(),
|
||||
)
|
||||
return session
|
||||
|
||||
@@ -171,12 +192,9 @@ class VertexAiSessionService(BaseSessionService):
|
||||
api_client = self._get_api_client()
|
||||
|
||||
# Get session resource
|
||||
get_session_api_response = await api_client.async_request(
|
||||
http_method='GET',
|
||||
path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
|
||||
request_dict={},
|
||||
get_session_api_response = await self._get_session_api_response(
|
||||
reasoning_engine_id, session_id, api_client
|
||||
)
|
||||
get_session_api_response = _convert_api_response(get_session_api_response)
|
||||
|
||||
if get_session_api_response['userId'] != user_id:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
|
||||
Reference in New Issue
Block a user