diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index fd1d8e58..bff2b675 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -95,6 +95,13 @@ class Gemini(BaseLlm): ) logger.info(_build_request_log(llm_request)) + # add tracking headers to custom headers given it will override the headers + # set in the api client constructor + if llm_request.config and llm_request.config.http_options: + if not llm_request.config.http_options.headers: + llm_request.config.http_options.headers = {} + llm_request.config.http_options.headers.update(self._tracking_headers) + if stream: responses = await self.api_client.aio.models.generate_content_stream( model=llm_request.model, @@ -201,24 +208,21 @@ class Gemini(BaseLlm): return tracking_headers @cached_property - def _live_api_client(self) -> Client: + def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: # use beta version for vertex api - api_version = 'v1beta1' - # use default api version for vertex - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1beta1' else: # use v1alpha for using API KEY from Google AI Studio - api_version = 'v1alpha' - return Client( - http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=api_version - ) - ) + return 'v1alpha' + + @cached_property + def _live_api_client(self) -> Client: + return Client( + http_options=types.HttpOptions( + headers=self._tracking_headers, api_version=self._live_api_version + ) + ) @contextlib.asynccontextmanager async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: @@ -230,6 +234,21 @@ class Gemini(BaseLlm): Yields: BaseLlmConnection, the connection to the Gemini model. """ + # add tracking headers to custom headers and set api_version given + # the customized http options will override the one set in the api client + # constructor + if ( + llm_request.live_connect_config + and llm_request.live_connect_config.http_options + ): + if not llm_request.live_connect_config.http_options.headers: + llm_request.live_connect_config.http_options.headers = {} + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers + ) + llm_request.live_connect_config.http_options.api_version = ( + self._live_api_version + ) llm_request.live_connect_config.system_instruction = types.Content( role='system', diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 9278bee5..fb8540bb 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -341,6 +341,255 @@ async def test_connect(gemini_llm, llm_request): assert connection is mock_connection +@pytest.mark.asyncio +async def test_generate_content_async_with_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are updated when custom headers are provided.""" + # Add custom headers to the request config + custom_headers = {"custom-header": "custom-value"} + for key in gemini_llm._tracking_headers: + custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content contains merged headers + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + + for key, value in config_arg.http_options.headers.items(): + if key in gemini_llm._tracking_headers: + assert value == gemini_llm._tracking_headers[key] + else: + assert value == custom_headers[key] + + assert len(responses) == 1 + assert isinstance(responses[0], LlmResponse) + + +@pytest.mark.asyncio +async def test_generate_content_async_stream_with_custom_headers( + gemini_llm, llm_request +): + """Test that tracking headers are updated when custom headers are provided in streaming mode.""" + # Add custom headers to the request config + custom_headers = {"custom-header": "custom-value"} + llm_request.config.http_options = types.HttpOptions(headers=custom_headers) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + # Create mock stream responses + class MockAsyncIterator: + + def __init__(self, seq): + self.iter = iter(seq) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iter) + except StopIteration: + raise StopAsyncIteration + + mock_responses = [ + types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", parts=[Part.from_text(text="Hello")] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + ] + + async def mock_coro(): + return MockAsyncIterator(mock_responses) + + mock_client.aio.models.generate_content_stream.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=True + ) + ] + + # Verify that the config passed to generate_content_stream contains merged headers + mock_client.aio.models.generate_content_stream.assert_called_once() + call_args = mock_client.aio.models.generate_content_stream.call_args + config_arg = call_args.kwargs["config"] + + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + assert len(responses) == 2 + + +@pytest.mark.asyncio +async def test_generate_content_async_without_custom_headers( + gemini_llm, llm_request, generate_content_response +): + """Test that tracking headers are not modified when no custom headers exist.""" + # Ensure no http_options exist initially + llm_request.config.http_options = None + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + responses = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + # Verify that the config passed to generate_content has no http_options + mock_client.aio.models.generate_content.assert_called_once() + call_args = mock_client.aio.models.generate_content.call_args + config_arg = call_args.kwargs["config"] + assert config_arg.http_options is None + + assert len(responses) == 1 + + +def test_live_api_version_vertex_ai(gemini_llm): + """Test that _live_api_version returns 'v1beta1' for Vertex AI backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + assert gemini_llm._live_api_version == "v1beta1" + + +def test_live_api_version_gemini_api(gemini_llm): + """Test that _live_api_version returns 'v1alpha' for Gemini API backend.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API + ): + assert gemini_llm._live_api_version == "v1alpha" + + +def test_live_api_client_properties(gemini_llm): + """Test that _live_api_client is properly configured with tracking headers and API version.""" + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + client = gemini_llm._live_api_client + + # Verify that the client has the correct headers and API version + http_options = client._api_client._http_options + assert http_options.api_version == "v1beta1" + + # Check that tracking headers are included + tracking_headers = gemini_llm._tracking_headers + for key, value in tracking_headers.items(): + assert key in http_options.headers + assert value in http_options.headers[key] + + +@pytest.mark.asyncio +async def test_connect_with_custom_headers(gemini_llm, llm_request): + """Test that connect method updates tracking headers and API version when custom headers are provided.""" + # Setup request with live connect config and custom headers + custom_headers = {"custom-live-header": "live-value"} + llm_request.live_connect_config = types.LiveConnectConfig( + http_options=types.HttpOptions(headers=custom_headers) + ) + + mock_live_session = mock.AsyncMock() + + # Mock the _live_api_client to return a mock client + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + # Create a mock context manager + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that tracking headers were merged with custom headers + expected_headers = custom_headers.copy() + expected_headers.update(gemini_llm._tracking_headers) + assert config_arg.http_options.headers == expected_headers + + # Verify that API version was set + assert config_arg.http_options.api_version == gemini_llm._live_api_version + + # Verify that system instruction and tools were set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + # Verify connection is properly wrapped + assert isinstance(connection, GeminiLlmConnection) + + +@pytest.mark.asyncio +async def test_connect_without_custom_headers(gemini_llm, llm_request): + """Test that connect method works properly when no custom headers are provided.""" + # Setup request with live connect config but no custom headers + llm_request.live_connect_config = types.LiveConnectConfig() + + mock_live_session = mock.AsyncMock() + + with mock.patch.object(gemini_llm, "_live_api_client") as mock_live_client: + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + mock_live_client.aio.live.connect.return_value = MockLiveConnect() + + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] + + # Verify that http_options remains None since no custom headers were provided + assert config_arg.http_options is None + + # Verify that system instruction and tools were still set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools + + assert isinstance(connection, GeminiLlmConnection) + + @pytest.mark.parametrize( ( "api_backend, "