diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 6628eb95..eac13367 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -25,6 +25,11 @@ import urllib.parse from dateutil import parser from google.genai.errors import ClientError +from tenacity import retry +from tenacity import retry_if_result +from tenacity import RetryError +from tenacity import stop_after_attempt +from tenacity import wait_exponential from typing_extensions import override from google import genai @@ -64,6 +69,21 @@ class VertexAiSessionService(BaseSessionService): self._location = location self._agent_engine_id = agent_engine_id + async def _get_session_api_response( + self, + reasoning_engine_id: str, + session_id: str, + api_client: genai.ApiClient, + ): + """GETs the session resource and returns the converted API response.""" + get_session_api_response = await api_client.async_request( + http_method='GET', + path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', + request_dict={}, + ) + get_session_api_response = _convert_api_response(get_session_api_response) + return get_session_api_response + @override async def create_session( self, @@ -95,66 +115,68 @@ class VertexAiSessionService(BaseSessionService): session_id = api_response['name'].split('/')[-3] operation_id = api_response['name'].split('/')[-1] - - max_retry_attempt = 5 - if _is_vertex_express_mode(self._project, self._location): # Express mode doesn't support LRO, so we need to poll # the session resource. # TODO: remove this once LRO polling is supported in Express mode. 
- for i in range(max_retry_attempt): + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=3), + retry=retry_if_result(lambda response: not response), + reraise=True, + ) + async def _poll_session_resource(): try: - await api_client.async_request( - http_method='GET', - path=( - f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}' - ), - request_dict={}, + return await self._get_session_api_response( + reasoning_engine_id, session_id, api_client ) - break - except ClientError as e: - logger.info('Polling for session %s: %s', session_id, e) - # Add slight exponential backoff to avoid excessive polling. - await asyncio.sleep(1 + 0.5 * i) - else: - raise TimeoutError('Session creation failed.') + except ClientError as e: + logger.info('Polling for session %s: %s', session_id, e) + return None + + try: + await _poll_session_resource() + except Exception as exc: + raise ValueError('Failed to create session.') from exc else: - lro_response = None - for _ in range(max_retry_attempt): + + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1, max=3), + retry=retry_if_result( + lambda response: not response.get('done', False), + ), + reraise=True, + ) + async def _poll_lro(): lro_response = await api_client.async_request( http_method='GET', path=f'operations/{operation_id}', request_dict={}, ) lro_response = _convert_api_response(lro_response) + return lro_response - if lro_response.get('done', None): - break - - await asyncio.sleep(1) - - if lro_response is None or not lro_response.get('done', None): + try: + await _poll_lro() + except RetryError as exc: raise TimeoutError( f'Timeout waiting for operation {operation_id} to complete.' 
- ) + ) from exc + except Exception as exc: + raise ValueError('Failed to create session.') from exc - # Get session resource - get_session_api_response = await api_client.async_request( - http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, + get_session_api_response = await self._get_session_api_response( + reasoning_engine_id, session_id, api_client ) - get_session_api_response = _convert_api_response(get_session_api_response) - - update_timestamp = isoparse( - get_session_api_response['updateTime'] - ).timestamp() session = Session( app_name=str(app_name), user_id=str(user_id), id=str(session_id), state=get_session_api_response.get('sessionState', {}), - last_update_time=update_timestamp, + last_update_time=isoparse( + get_session_api_response['updateTime'] + ).timestamp(), ) return session @@ -171,12 +193,9 @@ class VertexAiSessionService(BaseSessionService): api_client = self._get_api_client() # Get session resource - get_session_api_response = await api_client.async_request( - http_method='GET', - path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}', - request_dict={}, + get_session_api_response = await self._get_session_api_response( + reasoning_engine_id, session_id, api_client ) - get_session_api_response = _convert_api_response(get_session_api_response) if get_session_api_response['userId'] != user_id: raise ValueError(f'Session not found: {session_id}')