You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Merge custom http options with adk specific http options in model api request
PiperOrigin-RevId: 770836112
This commit is contained in:
committed by
Copybara-Service
parent
d22920bd7f
commit
4ccda99e8e
@@ -95,6 +95,13 @@ class Gemini(BaseLlm):
|
||||
)
|
||||
logger.info(_build_request_log(llm_request))
|
||||
|
||||
# add tracking headers to custom headers given it will override the headers
|
||||
# set in the api client constructor
|
||||
if llm_request.config and llm_request.config.http_options:
|
||||
if not llm_request.config.http_options.headers:
|
||||
llm_request.config.http_options.headers = {}
|
||||
llm_request.config.http_options.headers.update(self._tracking_headers)
|
||||
|
||||
if stream:
|
||||
responses = await self.api_client.aio.models.generate_content_stream(
|
||||
model=llm_request.model,
|
||||
@@ -201,24 +208,21 @@ class Gemini(BaseLlm):
|
||||
return tracking_headers
|
||||
|
||||
@cached_property
|
||||
def _live_api_client(self) -> Client:
|
||||
def _live_api_version(self) -> str:
|
||||
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
|
||||
# use beta version for vertex api
|
||||
api_version = 'v1beta1'
|
||||
# use default api version for vertex
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=api_version
|
||||
)
|
||||
)
|
||||
return 'v1beta1'
|
||||
else:
|
||||
# use v1alpha for using API KEY from Google AI Studio
|
||||
api_version = 'v1alpha'
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=api_version
|
||||
)
|
||||
)
|
||||
return 'v1alpha'
|
||||
|
||||
@cached_property
|
||||
def _live_api_client(self) -> Client:
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=self._live_api_version
|
||||
)
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
@@ -230,6 +234,21 @@ class Gemini(BaseLlm):
|
||||
Yields:
|
||||
BaseLlmConnection, the connection to the Gemini model.
|
||||
"""
|
||||
# add tracking headers to custom headers and set api_version given
|
||||
# the customized http options will override the one set in the api client
|
||||
# constructor
|
||||
if (
|
||||
llm_request.live_connect_config
|
||||
and llm_request.live_connect_config.http_options
|
||||
):
|
||||
if not llm_request.live_connect_config.http_options.headers:
|
||||
llm_request.live_connect_config.http_options.headers = {}
|
||||
llm_request.live_connect_config.http_options.headers.update(
|
||||
self._tracking_headers
|
||||
)
|
||||
llm_request.live_connect_config.http_options.api_version = (
|
||||
self._live_api_version
|
||||
)
|
||||
|
||||
llm_request.live_connect_config.system_instruction = types.Content(
|
||||
role='system',
|
||||
|
||||
@@ -341,6 +341,255 @@ async def test_connect(gemini_llm, llm_request):
|
||||
assert connection is mock_connection
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_with_custom_headers(
|
||||
gemini_llm, llm_request, generate_content_response
|
||||
):
|
||||
"""Test that tracking headers are updated when custom headers are provided."""
|
||||
# Add custom headers to the request config
|
||||
custom_headers = {"custom-header": "custom-value"}
|
||||
for key in gemini_llm._tracking_headers:
|
||||
custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
|
||||
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
|
||||
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
# Create a mock coroutine that returns the generate_content_response
|
||||
async def mock_coro():
|
||||
return generate_content_response
|
||||
|
||||
mock_client.aio.models.generate_content.return_value = mock_coro()
|
||||
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=False
|
||||
)
|
||||
]
|
||||
|
||||
# Verify that the config passed to generate_content contains merged headers
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
call_args = mock_client.aio.models.generate_content.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
for key, value in config_arg.http_options.headers.items():
|
||||
if key in gemini_llm._tracking_headers:
|
||||
assert value == gemini_llm._tracking_headers[key]
|
||||
else:
|
||||
assert value == custom_headers[key]
|
||||
|
||||
assert len(responses) == 1
|
||||
assert isinstance(responses[0], LlmResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_stream_with_custom_headers(
|
||||
gemini_llm, llm_request
|
||||
):
|
||||
"""Test that tracking headers are updated when custom headers are provided in streaming mode."""
|
||||
# Add custom headers to the request config
|
||||
custom_headers = {"custom-header": "custom-value"}
|
||||
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
|
||||
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
# Create mock stream responses
|
||||
class MockAsyncIterator:
|
||||
|
||||
def __init__(self, seq):
|
||||
self.iter = iter(seq)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
return next(self.iter)
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
mock_responses = [
|
||||
types.GenerateContentResponse(
|
||||
candidates=[
|
||||
types.Candidate(
|
||||
content=Content(
|
||||
role="model", parts=[Part.from_text(text="Hello")]
|
||||
),
|
||||
finish_reason=types.FinishReason.STOP,
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
async def mock_coro():
|
||||
return MockAsyncIterator(mock_responses)
|
||||
|
||||
mock_client.aio.models.generate_content_stream.return_value = mock_coro()
|
||||
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=True
|
||||
)
|
||||
]
|
||||
|
||||
# Verify that the config passed to generate_content_stream contains merged headers
|
||||
mock_client.aio.models.generate_content_stream.assert_called_once()
|
||||
call_args = mock_client.aio.models.generate_content_stream.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(gemini_llm._tracking_headers)
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
assert len(responses) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_without_custom_headers(
|
||||
gemini_llm, llm_request, generate_content_response
|
||||
):
|
||||
"""Test that tracking headers are not modified when no custom headers exist."""
|
||||
# Ensure no http_options exist initially
|
||||
llm_request.config.http_options = None
|
||||
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
|
||||
async def mock_coro():
|
||||
return generate_content_response
|
||||
|
||||
mock_client.aio.models.generate_content.return_value = mock_coro()
|
||||
|
||||
responses = [
|
||||
resp
|
||||
async for resp in gemini_llm.generate_content_async(
|
||||
llm_request, stream=False
|
||||
)
|
||||
]
|
||||
|
||||
# Verify that the config passed to generate_content has no http_options
|
||||
mock_client.aio.models.generate_content.assert_called_once()
|
||||
call_args = mock_client.aio.models.generate_content.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
assert config_arg.http_options is None
|
||||
|
||||
assert len(responses) == 1
|
||||
|
||||
|
||||
def test_live_api_version_vertex_ai(gemini_llm):
|
||||
"""Test that _live_api_version returns 'v1beta1' for Vertex AI backend."""
|
||||
with mock.patch.object(
|
||||
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
|
||||
):
|
||||
assert gemini_llm._live_api_version == "v1beta1"
|
||||
|
||||
|
||||
def test_live_api_version_gemini_api(gemini_llm):
|
||||
"""Test that _live_api_version returns 'v1alpha' for Gemini API backend."""
|
||||
with mock.patch.object(
|
||||
gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API
|
||||
):
|
||||
assert gemini_llm._live_api_version == "v1alpha"
|
||||
|
||||
|
||||
def test_live_api_client_properties(gemini_llm):
|
||||
"""Test that _live_api_client is properly configured with tracking headers and API version."""
|
||||
with mock.patch.object(
|
||||
gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
|
||||
):
|
||||
client = gemini_llm._live_api_client
|
||||
|
||||
# Verify that the client has the correct headers and API version
|
||||
http_options = client._api_client._http_options
|
||||
assert http_options.api_version == "v1beta1"
|
||||
|
||||
# Check that tracking headers are included
|
||||
tracking_headers = gemini_llm._tracking_headers
|
||||
for key, value in tracking_headers.items():
|
||||
assert key in http_options.headers
|
||||
assert value in http_options.headers[key]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_with_custom_headers(gemini_llm, llm_request):
|
||||
"""Test that connect method updates tracking headers and API version when custom headers are provided."""
|
||||
# Setup request with live connect config and custom headers
|
||||
custom_headers = {"custom-live-header": "live-value"}
|
||||
llm_request.live_connect_config = types.LiveConnectConfig(
|
||||
http_options=types.HttpOptions(headers=custom_headers)
|
||||
)
|
||||
|
||||
mock_live_session = mock.AsyncMock()
|
||||
|
||||
# Mock the _live_api_client to return a mock client
|
||||
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
|
||||
# Create a mock context manager
|
||||
class MockLiveConnect:
|
||||
|
||||
async def __aenter__(self):
|
||||
return mock_live_session
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
|
||||
|
||||
async with gemini_llm.connect(llm_request) as connection:
|
||||
# Verify that the connect method was called with the right config
|
||||
mock_live_client.aio.live.connect.assert_called_once()
|
||||
call_args = mock_live_client.aio.live.connect.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
# Verify that tracking headers were merged with custom headers
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(gemini_llm._tracking_headers)
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
# Verify that API version was set
|
||||
assert config_arg.http_options.api_version == gemini_llm._live_api_version
|
||||
|
||||
# Verify that system instruction and tools were set
|
||||
assert config_arg.system_instruction is not None
|
||||
assert config_arg.tools == llm_request.config.tools
|
||||
|
||||
# Verify connection is properly wrapped
|
||||
assert isinstance(connection, GeminiLlmConnection)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_without_custom_headers(gemini_llm, llm_request):
|
||||
"""Test that connect method works properly when no custom headers are provided."""
|
||||
# Setup request with live connect config but no custom headers
|
||||
llm_request.live_connect_config = types.LiveConnectConfig()
|
||||
|
||||
mock_live_session = mock.AsyncMock()
|
||||
|
||||
with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client:
|
||||
|
||||
class MockLiveConnect:
|
||||
|
||||
async def __aenter__(self):
|
||||
return mock_live_session
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
|
||||
|
||||
async with gemini_llm.connect(llm_request) as connection:
|
||||
# Verify that the connect method was called with the right config
|
||||
mock_live_client.aio.live.connect.assert_called_once()
|
||||
call_args = mock_live_client.aio.live.connect.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
# Verify that http_options remains None since no custom headers were provided
|
||||
assert config_arg.http_options is None
|
||||
|
||||
# Verify that system instruction and tools were still set
|
||||
assert config_arg.system_instruction is not None
|
||||
assert config_arg.tools == llm_request.config.tools
|
||||
|
||||
assert isinstance(connection, GeminiLlmConnection)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"api_backend, "
|
||||
|
||||
Reference in New Issue
Block a user