diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index beae164a..1bdd3111 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING from typing import Union from google.genai import types +from google.genai.errors import ClientError from typing_extensions import override from .. import version @@ -51,6 +52,34 @@ _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} _AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID' +_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ +On how to mitigate this issue, please refer to: + +https://google.github.io/adk-docs/agents/models/#error-code-429-resource_exhausted +""" + + +class _ResourceExhaustedError(ClientError): + """Represents an resources exhausted error received from the Model.""" + + def __init__( + self, + client_error: ClientError, + ): + super().__init__( + code=client_error.code, + response_json=client_error.details, + response=client_error.response, + ) + + def __str__(self): + # We don't get override the actual message on ClientError, so we override + # this method instead. This will ensure that when the exception is + # stringified (for either publishing the exception on console or to logs) + # we put in the required details for the developer. + base_message = super().__str__() + return f'{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}' + class Gemini(BaseLlm): """Integration for Gemini models. @@ -149,50 +178,61 @@ class Gemini(BaseLlm): llm_request.config.http_options.headers ) - if stream: - responses = await self.api_client.aio.models.generate_content_stream( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, - ) + try: + if stream: + responses = await self.api_client.aio.models.generate_content_stream( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, + ) - # for sse, similar as bidi (see receive method in gemini_llm_connection.py), - # we need to mark those text content as partial and after all partial - # contents are sent, we send an accumulated event which contains all the - # previous partial content. The only difference is bidi rely on - # complete_turn flag to detect end while sse depends on finish_reason. - aggregator = StreamingResponseAggregator() - async with Aclosing(responses) as agen: - async for response in agen: - logger.debug(_build_response_log(response)) - async with Aclosing( - aggregator.process_response(response) - ) as aggregator_gen: - async for llm_response in aggregator_gen: - yield llm_response - if (close_result := aggregator.close()) is not None: - # Populate cache metadata in the final aggregated response for streaming + # for sse, similar as bidi (see receive method in + # gemini_llm_connection.py), we need to mark those text content as + # partial and after all partial contents are sent, we send an + # accumulated event which contains all the previous partial content. The + # only difference is bidi rely on complete_turn flag to detect end while + # sse depends on finish_reason. + aggregator = StreamingResponseAggregator() + async with Aclosing(responses) as agen: + async for response in agen: + logger.debug(_build_response_log(response)) + async with Aclosing( + aggregator.process_response(response) + ) as aggregator_gen: + async for llm_response in aggregator_gen: + yield llm_response + if (close_result := aggregator.close()) is not None: + # Populate cache metadata in the final aggregated response for + # streaming + if cache_metadata: + cache_manager.populate_cache_metadata_in_response( + close_result, cache_metadata + ) + yield close_result + + else: + response = await self.api_client.aio.models.generate_content( + model=llm_request.model, + contents=llm_request.contents, + config=llm_request.config, + ) + logger.info('Response received from the model.') + logger.debug(_build_response_log(response)) + + llm_response = LlmResponse.create(response) if cache_metadata: cache_manager.populate_cache_metadata_in_response( - close_result, cache_metadata + llm_response, cache_metadata ) - yield close_result + yield llm_response + except ClientError as ce: + if ce.code == 429: + # We expect running into a Resource Exhausted error to be a common + # client error that developers would run into. We enhance the messaging + # with possible fixes to this issue. + raise _ResourceExhaustedError(ce) from ce - else: - response = await self.api_client.aio.models.generate_content( - model=llm_request.model, - contents=llm_request.contents, - config=llm_request.config, - ) - logger.info('Response received from the model.') - logger.debug(_build_response_log(response)) - - llm_response = LlmResponse.create(response) - if cache_metadata: - cache_manager.populate_cache_metadata_in_response( - llm_response, cache_metadata - ) - yield llm_response + raise ce @cached_property def api_client(self) -> Client: diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index ed5d0335..f2419daf 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -26,11 +26,14 @@ from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NA from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import _build_function_declaration_log from google.adk.models.google_llm import _build_request_log +from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE +from google.adk.models.google_llm import _ResourceExhaustedError from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types +from google.genai.errors import ClientError from google.genai.types import Content from google.genai.types import Part import pytest @@ -386,6 +389,60 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts( mock_client.aio.models.generate_content_stream.assert_called_once() +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.asyncio +async def test_generate_content_async_resource_exhausted_error( + stream, gemini_llm, llm_request +): + with mock.patch.object(gemini_llm, "api_client") as mock_client: + err = ClientError(code=429, response_json={}) + err.code = 429 + if stream: + mock_client.aio.models.generate_content_stream.side_effect = err + else: + mock_client.aio.models.generate_content.side_effect = err + + with pytest.raises(_ResourceExhaustedError) as excinfo: + responses = [] + async for resp in gemini_llm.generate_content_async( + llm_request, stream=stream + ): + responses.append(resp) + assert _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE in str(excinfo.value) + assert excinfo.value.code == 429 + if stream: + mock_client.aio.models.generate_content_stream.assert_called_once() + else: + mock_client.aio.models.generate_content.assert_called_once() + + +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.asyncio +async def test_generate_content_async_other_client_error( + stream, gemini_llm, llm_request +): + with mock.patch.object(gemini_llm, "api_client") as mock_client: + err = ClientError(code=500, response_json={}) + err.code = 500 + if stream: + mock_client.aio.models.generate_content_stream.side_effect = err + else: + mock_client.aio.models.generate_content.side_effect = err + + with pytest.raises(ClientError) as excinfo: + responses = [] + async for resp in gemini_llm.generate_content_async( + llm_request, stream=stream + ): + responses.append(resp) + assert excinfo.value.code == 500 + assert not isinstance(excinfo.value, _ResourceExhaustedError) + if stream: + mock_client.aio.models.generate_content_stream.assert_called_once() + else: + mock_client.aio.models.generate_content.assert_called_once() + + @pytest.mark.asyncio async def test_connect(gemini_llm, llm_request): # Create a mock connection