You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Enhance the messaging with possible fixes for RESOURCE_EXHAUSTED errors from Gemini
Co-authored-by: Ankur Sharma <ankusharma@google.com> PiperOrigin-RevId: 833538475
This commit is contained in:
committed by
Copybara-Service
parent
5ac5129fb0
commit
b2c45f8d91
@@ -28,6 +28,7 @@ from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
@@ -51,6 +52,34 @@ _EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
|
||||
_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine'
|
||||
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'
|
||||
|
||||
# Mitigation hint that _ResourceExhaustedError.__str__ places ahead of the
# base ClientError message when the exception is stringified.
_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """
|
||||
On how to mitigate this issue, please refer to:
|
||||
|
||||
https://google.github.io/adk-docs/agents/models/#error-code-429-resource_exhausted
|
||||
"""
|
||||
|
||||
|
||||
class _ResourceExhaustedError(ClientError):
|
||||
"""Represents a resource exhausted error received from the model.

Wraps an existing ClientError by copying its code, details and response into
a new instance, and overrides __str__ so that the stringified error starts
with _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE.
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_error: ClientError,
|
||||
):
|
||||
super().__init__(
|
||||
code=client_error.code,
|
||||
response_json=client_error.details,
|
||||
response=client_error.response,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
# We can't override the actual message on ClientError, so we override
|
||||
# this method instead. This ensures that when the exception is
|
||||
# stringified (either for printing on the console or for writing to logs)
|
||||
# we include the required details for the developer.
|
||||
base_message = super().__str__()
|
||||
return f'{_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE}\n\n{base_message}'
|
||||
|
||||
|
||||
class Gemini(BaseLlm):
|
||||
"""Integration for Gemini models.
|
||||
@@ -149,50 +178,61 @@ class Gemini(BaseLlm):
|
||||
llm_request.config.http_options.headers
|
||||
)
|
||||
|
||||
if stream:
|
||||
responses = await self.api_client.aio.models.generate_content_stream(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
try:
|
||||
if stream:
|
||||
responses = await self.api_client.aio.models.generate_content_stream(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
|
||||
# for sse, similar as bidi (see receive method in gemini_llm_connection.py),
|
||||
# we need to mark those text content as partial and after all partial
|
||||
# contents are sent, we send an accumulated event which contains all the
|
||||
# previous partial content. The only difference is bidi rely on
|
||||
# complete_turn flag to detect end while sse depends on finish_reason.
|
||||
aggregator = StreamingResponseAggregator()
|
||||
async with Aclosing(responses) as agen:
|
||||
async for response in agen:
|
||||
logger.debug(_build_response_log(response))
|
||||
async with Aclosing(
|
||||
aggregator.process_response(response)
|
||||
) as aggregator_gen:
|
||||
async for llm_response in aggregator_gen:
|
||||
yield llm_response
|
||||
if (close_result := aggregator.close()) is not None:
|
||||
# Populate cache metadata in the final aggregated response for streaming
|
||||
# for sse, similar as bidi (see receive method in
|
||||
# gemini_llm_connection.py), we need to mark those text content as
|
||||
# partial and after all partial contents are sent, we send an
|
||||
# accumulated event which contains all the previous partial content. The
|
||||
# only difference is bidi rely on complete_turn flag to detect end while
|
||||
# sse depends on finish_reason.
|
||||
aggregator = StreamingResponseAggregator()
|
||||
async with Aclosing(responses) as agen:
|
||||
async for response in agen:
|
||||
logger.debug(_build_response_log(response))
|
||||
async with Aclosing(
|
||||
aggregator.process_response(response)
|
||||
) as aggregator_gen:
|
||||
async for llm_response in aggregator_gen:
|
||||
yield llm_response
|
||||
if (close_result := aggregator.close()) is not None:
|
||||
# Populate cache metadata in the final aggregated response for
|
||||
# streaming
|
||||
if cache_metadata:
|
||||
cache_manager.populate_cache_metadata_in_response(
|
||||
close_result, cache_metadata
|
||||
)
|
||||
yield close_result
|
||||
|
||||
else:
|
||||
response = await self.api_client.aio.models.generate_content(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
logger.info('Response received from the model.')
|
||||
logger.debug(_build_response_log(response))
|
||||
|
||||
llm_response = LlmResponse.create(response)
|
||||
if cache_metadata:
|
||||
cache_manager.populate_cache_metadata_in_response(
|
||||
close_result, cache_metadata
|
||||
llm_response, cache_metadata
|
||||
)
|
||||
yield close_result
|
||||
yield llm_response
|
||||
except ClientError as ce:
|
||||
if ce.code == 429:
|
||||
# We expect running into a Resource Exhausted error to be a common
|
||||
# client error that developers would run into. We enhance the messaging
|
||||
# with possible fixes to this issue.
|
||||
raise _ResourceExhaustedError(ce) from ce
|
||||
|
||||
else:
|
||||
response = await self.api_client.aio.models.generate_content(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
logger.info('Response received from the model.')
|
||||
logger.debug(_build_response_log(response))
|
||||
|
||||
llm_response = LlmResponse.create(response)
|
||||
if cache_metadata:
|
||||
cache_manager.populate_cache_metadata_in_response(
|
||||
llm_response, cache_metadata
|
||||
)
|
||||
yield llm_response
|
||||
raise ce
|
||||
|
||||
@cached_property
|
||||
def api_client(self) -> Client:
|
||||
|
||||
@@ -26,11 +26,14 @@ from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NA
|
||||
from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG
|
||||
from google.adk.models.google_llm import _build_function_declaration_log
|
||||
from google.adk.models.google_llm import _build_request_log
|
||||
from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE
|
||||
from google.adk.models.google_llm import _ResourceExhaustedError
|
||||
from google.adk.models.google_llm import Gemini
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
from google.genai.types import Content
|
||||
from google.genai.types import Part
|
||||
import pytest
|
||||
@@ -386,6 +389,60 @@ async def test_generate_content_async_stream_preserves_thinking_and_text_parts(
|
||||
mock_client.aio.models.generate_content_stream.assert_called_once()
|
||||
|
||||
|
||||
# Checks, for both stream=True and stream=False, that a ClientError with code
# 429 raised by the mocked API client surfaces from generate_content_async as
# _ResourceExhaustedError, with the fix message in its string form.
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_resource_exhausted_error(
|
||||
stream, gemini_llm, llm_request
|
||||
):
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
err = ClientError(code=429, response_json={})
|
||||
# NOTE(review): code=429 is already passed to the constructor above, so this
# assignment looks redundant - confirm against ClientError before removing.
err.code = 429
|
||||
# Make only the client method matching the requested mode raise the error.
if stream:
|
||||
mock_client.aio.models.generate_content_stream.side_effect = err
|
||||
else:
|
||||
mock_client.aio.models.generate_content.side_effect = err
|
||||
|
||||
with pytest.raises(_ResourceExhaustedError) as excinfo:
|
||||
# The async generator has to be iterated for the mocked call to happen.
responses = []
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=stream
|
||||
):
|
||||
responses.append(resp)
|
||||
# The wrapped error must carry the mitigation hint and keep the 429 code.
assert _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE in str(excinfo.value)
|
||||
assert excinfo.value.code == 429
|
||||
# The client method for the requested mode must have been called exactly once.
if stream:
|
||||
mock_client.aio.models.generate_content_stream.assert_called_once()
|
||||
else:
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
# Checks, for both stream=True and stream=False, that a ClientError whose code
# is not 429 (500 here) is raised from generate_content_async as a ClientError
# that is not a _ResourceExhaustedError.
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_other_client_error(
|
||||
stream, gemini_llm, llm_request
|
||||
):
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
err = ClientError(code=500, response_json={})
|
||||
# NOTE(review): code=500 is already passed to the constructor above, so this
# assignment looks redundant - confirm against ClientError before removing.
err.code = 500
|
||||
# Make only the client method matching the requested mode raise the error.
if stream:
|
||||
mock_client.aio.models.generate_content_stream.side_effect = err
|
||||
else:
|
||||
mock_client.aio.models.generate_content.side_effect = err
|
||||
|
||||
with pytest.raises(ClientError) as excinfo:
|
||||
# The async generator has to be iterated for the mocked call to happen.
responses = []
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=stream
|
||||
):
|
||||
responses.append(resp)
|
||||
# The error keeps its original code and is not wrapped in the 429-specific type.
assert excinfo.value.code == 500
|
||||
assert not isinstance(excinfo.value, _ResourceExhaustedError)
|
||||
# The client method for the requested mode must have been called exactly once.
if stream:
|
||||
mock_client.aio.models.generate_content_stream.assert_called_once()
|
||||
else:
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(gemini_llm, llm_request):
|
||||
# Create a mock connection
|
||||
|
||||
Reference in New Issue
Block a user