fix: Merge custom http options with adk specific http options in model api request

PiperOrigin-RevId: 770836112
This commit is contained in:
Xiang (Sean) Zhou
2025-06-12 16:19:42 -07:00
committed by Copybara-Service
parent d22920bd7f
commit 4ccda99e8e
2 changed files with 282 additions and 14 deletions
+33 -14
View File
@@ -95,6 +95,13 @@ class Gemini(BaseLlm):
)
logger.info(_build_request_log(llm_request))
# add tracking headers to custom headers given it will override the headers
# set in the api client constructor
if llm_request.config and llm_request.config.http_options:
if not llm_request.config.http_options.headers:
llm_request.config.http_options.headers = {}
llm_request.config.http_options.headers.update(self._tracking_headers)
if stream:
responses = await self.api_client.aio.models.generate_content_stream(
model=llm_request.model,
@@ -201,24 +208,21 @@ class Gemini(BaseLlm):
return tracking_headers
@cached_property
def _live_api_client(self) -> Client:
def _live_api_version(self) -> str:
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
# use beta version for vertex api
api_version = 'v1beta1'
# use default api version for vertex
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers, api_version=api_version
)
)
return 'v1beta1'
else:
# use v1alpha for using API KEY from Google AI Studio
api_version = 'v1alpha'
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers, api_version=api_version
)
)
return 'v1alpha'
@cached_property
def _live_api_client(self) -> Client:
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers, api_version=self._live_api_version
)
)
@contextlib.asynccontextmanager
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
@@ -230,6 +234,21 @@ class Gemini(BaseLlm):
Yields:
BaseLlmConnection, the connection to the Gemini model.
"""
# add tracking headers to custom headers and set api_version given
# the customized http options will override the one set in the api client
# constructor
if (
llm_request.live_connect_config
and llm_request.live_connect_config.http_options
):
if not llm_request.live_connect_config.http_options.headers:
llm_request.live_connect_config.http_options.headers = {}
llm_request.live_connect_config.http_options.headers.update(
self._tracking_headers
)
llm_request.live_connect_config.http_options.api_version = (
self._live_api_version
)
llm_request.live_connect_config.system_instruction = types.Content(
role='system',
+249
View File
@@ -341,6 +341,255 @@ async def test_connect(gemini_llm, llm_request):
assert connection is mock_connection
@pytest.mark.asyncio
async def test_generate_content_async_with_custom_headers(
gemini_llm, llm_request, generate_content_response
):
"""Test that tracking headers are updated when custom headers are provided."""
# Add custom headers to the request config
custom_headers = {"custom-header": "custom-value"}
for key in gemini_llm._tracking_headers:
custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
with mock.patch.object(gemini_llm, "api_client") as mock_client:
# Create a mock coroutine that returns the generate_content_response
async def mock_coro():
return generate_content_response
mock_client.aio.models.generate_content.return_value = mock_coro()
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=False
)
]
# Verify that the config passed to generate_content contains merged headers
mock_client.aio.models.generate_content.assert_called_once()
call_args = mock_client.aio.models.generate_content.call_args
config_arg = call_args.kwargs["config"]
for key, value in config_arg.http_options.headers.items():
if key in gemini_llm._tracking_headers:
assert value == gemini_llm._tracking_headers[key]
else:
assert value == custom_headers[key]
assert len(responses) == 1
assert isinstance(responses[0], LlmResponse)
@pytest.mark.asyncio
async def test_generate_content_async_stream_with_custom_headers(
gemini_llm, llm_request
):
"""Test that tracking headers are updated when custom headers are provided in streaming mode."""
# Add custom headers to the request config
custom_headers = {"custom-header": "custom-value"}
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
with mock.patch.object(gemini_llm, "api_client") as mock_client:
# Create mock stream responses
class MockAsyncIterator:
def __init__(self, seq):
self.iter = iter(seq)
def __aiter__(self):
return self
async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration
mock_responses = [
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model", parts=[Part.from_text(text="Hello")]
),
finish_reason=types.FinishReason.STOP,
)
]
)
]
async def mock_coro():
return MockAsyncIterator(mock_responses)
mock_client.aio.models.generate_content_stream.return_value = mock_coro()
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=True
)
]
# Verify that the config passed to generate_content_stream contains merged headers
mock_client.aio.models.generate_content_stream.assert_called_once()
call_args = mock_client.aio.models.generate_content_stream.call_args
config_arg = call_args.kwargs["config"]
expected_headers = custom_headers.copy()
expected_headers.update(gemini_llm._tracking_headers)
assert config_arg.http_options.headers == expected_headers
assert len(responses) == 2
@pytest.mark.asyncio
async def test_generate_content_async_without_custom_headers(
gemini_llm, llm_request, generate_content_response
):
"""Test that tracking headers are not modified when no custom headers exist."""
# Ensure no http_options exist initially
llm_request.config.http_options = None
with mock.patch.object(gemini_llm, "api_client") as mock_client:
async def mock_coro():
return generate_content_response
mock_client.aio.models.generate_content.return_value = mock_coro()
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=False
)
]
# Verify that the config passed to generate_content has no http_options
mock_client.aio.models.generate_content.assert_called_once()
call_args = mock_client.aio.models.generate_content.call_args
config_arg = call_args.kwargs["config"]
assert config_arg.http_options is None
assert len(responses) == 1
def test_live_api_version_vertex_ai(gemini_llm):
"""Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
):
assert gemini_llm._live_api_version == "v1beta1"
def test_live_api_version_gemini_api(gemini_llm):
"""Test that _live_api_version returns 'v1alpha' for Gemini API backend."""
with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API
):
assert gemini_llm._live_api_version == "v1alpha"
def test_live_api_client_properties(gemini_llm):
"""Test that _live_api_client is properly configured with tracking headers and API version."""
with mock.patch.object(
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
):
client = gemini_llm._live_api_client
# Verify that the client has the correct headers and API version
http_options = client._api_client._http_options
assert http_options.api_version == "v1beta1"
# Check that tracking headers are included
tracking_headers = gemini_llm._tracking_headers
for key, value in tracking_headers.items():
assert key in http_options.headers
assert value in http_options.headers[key]
@pytest.mark.asyncio
async def test_connect_with_custom_headers(gemini_llm, llm_request):
"""Test that connect method updates tracking headers and API version when custom headers are provided."""
# Setup request with live connect config and custom headers
custom_headers = {"custom-live-header": "live-value"}
llm_request.live_connect_config = types.LiveConnectConfig(
http_options=types.HttpOptions(headers=custom_headers)
)
mock_live_session = mock.AsyncMock()
# Mock the _live_api_client to return a mock client
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
# Create a mock context manager
class MockLiveConnect:
async def __aenter__(self):
return mock_live_session
async def __aexit__(self, *args):
pass
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
async with gemini_llm.connect(llm_request) as connection:
# Verify that the connect method was called with the right config
mock_live_client.aio.live.connect.assert_called_once()
call_args = mock_live_client.aio.live.connect.call_args
config_arg = call_args.kwargs["config"]
# Verify that tracking headers were merged with custom headers
expected_headers = custom_headers.copy()
expected_headers.update(gemini_llm._tracking_headers)
assert config_arg.http_options.headers == expected_headers
# Verify that API version was set
assert config_arg.http_options.api_version == gemini_llm._live_api_version
# Verify that system instruction and tools were set
assert config_arg.system_instruction is not None
assert config_arg.tools == llm_request.config.tools
# Verify connection is properly wrapped
assert isinstance(connection, GeminiLlmConnection)
@pytest.mark.asyncio
async def test_connect_without_custom_headers(gemini_llm, llm_request):
"""Test that connect method works properly when no custom headers are provided."""
# Setup request with live connect config but no custom headers
llm_request.live_connect_config = types.LiveConnectConfig()
mock_live_session = mock.AsyncMock()
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
class MockLiveConnect:
async def __aenter__(self):
return mock_live_session
async def __aexit__(self, *args):
pass
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
async with gemini_llm.connect(llm_request) as connection:
# Verify that the connect method was called with the right config
mock_live_client.aio.live.connect.assert_called_once()
call_args = mock_live_client.aio.live.connect.call_args
config_arg = call_args.kwargs["config"]
# Verify that http_options remains None since no custom headers were provided
assert config_arg.http_options is None
# Verify that system instruction and tools were still set
assert config_arg.system_instruction is not None
assert config_arg.tools == llm_request.config.tools
assert isinstance(connection, GeminiLlmConnection)
@pytest.mark.parametrize(
(
"api_backend, "