From 871571d997dedd9430d421cc366920321abbab5d Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 23 Dec 2025 06:50:06 -0800 Subject: [PATCH] feat: Mark Vertex calls made from non-gemini models PiperOrigin-RevId: 848159669 --- src/google/adk/models/anthropic_llm.py | 2 - src/google/adk/models/apigee_llm.py | 3 +- src/google/adk/models/google_llm.py | 41 +++++++--- src/google/adk/models/lite_llm.py | 21 ----- .../adk/utils/_google_client_headers.py | 56 ------------- tests/unittests/models/test_anthropic_llm.py | 25 ------ tests/unittests/models/test_google_llm.py | 13 ++- tests/unittests/models/test_litellm.py | 35 +------- .../utils/test_google_client_headers.py | 79 ------------------- 9 files changed, 41 insertions(+), 234 deletions(-) delete mode 100644 src/google/adk/utils/_google_client_headers.py delete mode 100644 tests/unittests/utils/test_google_client_headers.py diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index e39a85fd..163fbe45 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -36,7 +36,6 @@ from google.genai import types from pydantic import BaseModel from typing_extensions import override -from ..utils._google_client_headers import get_tracking_headers from .base_llm import BaseLlm from .llm_response import LlmResponse @@ -346,5 +345,4 @@ class Claude(AnthropicLlm): return AsyncAnthropicVertex( project_id=os.environ["GOOGLE_CLOUD_PROJECT"], region=os.environ["GOOGLE_CLOUD_LOCATION"], - default_headers=get_tracking_headers(), ) diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 167f3d1f..a2962021 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -25,7 +25,6 @@ from google.adk import version as adk_version from google.genai import types from typing_extensions import override -from ..utils._google_client_headers import merge_tracking_headers from ..utils.env_utils import is_env_enabled from .google_llm import Gemini @@ -146,7 +145,7 @@ class ApigeeLlm(Gemini): kwargs_for_http_options['api_version'] = self._api_version http_options = types.HttpOptions( base_url=self._proxy_url, - headers=merge_tracking_headers(self._custom_headers), + headers=self._merge_tracking_headers(self._custom_headers), retry_options=self.retry_options, **kwargs_for_http_options, ) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 35afe25c..9261fada 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -30,8 +30,7 @@ from google.genai import types from google.genai.errors import ClientError from typing_extensions import override -from ..utils._google_client_headers import get_tracking_headers -from ..utils._google_client_headers import merge_tracking_headers +from ..utils._client_labels_utils import get_client_labels from ..utils.context_utils import Aclosing from ..utils.streaming_utils import StreamingResponseAggregator from ..utils.variant_utils import GoogleLLMVariant @@ -192,7 +191,7 @@ class Gemini(BaseLlm): if llm_request.config: if not llm_request.config.http_options: llm_request.config.http_options = types.HttpOptions() - llm_request.config.http_options.headers = merge_tracking_headers( + llm_request.config.http_options.headers = self._merge_tracking_headers( llm_request.config.http_options.headers ) @@ -303,7 +302,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=get_tracking_headers(), + headers=self._tracking_headers(), retry_options=self.retry_options, ) ) @@ -316,6 +315,15 @@ class Gemini(BaseLlm): else GoogleLLMVariant.GEMINI_API ) + def _tracking_headers(self) -> dict[str, str]: + labels = get_client_labels() + header_value = ' '.join(labels) + tracking_headers = { + 'x-goog-api-client': header_value, + 'user-agent': header_value, + } + return tracking_headers + @cached_property def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: @@ -331,7 +339,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=get_tracking_headers(), api_version=self._live_api_version + headers=self._tracking_headers(), api_version=self._live_api_version ) ) @@ -354,10 +362,8 @@ class Gemini(BaseLlm): ): if not llm_request.live_connect_config.http_options.headers: llm_request.live_connect_config.http_options.headers = {} - llm_request.live_connect_config.http_options.headers = ( - merge_tracking_headers( - llm_request.live_connect_config.http_options.headers - ) + llm_request.live_connect_config.http_options.headers.update( + self._tracking_headers() ) llm_request.live_connect_config.http_options.api_version = ( self._live_api_version @@ -441,6 +447,23 @@ class Gemini(BaseLlm): llm_request.config.system_instruction = None await self._adapt_computer_use_tool(llm_request) + def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: + """Merge tracking headers to the given headers.""" + headers = headers or {} + for key, tracking_header_value in self._tracking_headers().items(): + custom_value = headers.get(key, None) + if not custom_value: + headers[key] = tracking_header_value + continue + + # Merge tracking headers with existing headers and avoid duplicates. + value_parts = tracking_header_value.split(' ') + for custom_value_part in custom_value.split(' '): + if custom_value_part not in value_parts: + value_parts.append(custom_value_part) + headers[key] = ' '.join(value_parts) + return headers + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index f5331d7c..14047398 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -57,7 +57,6 @@ from pydantic import BaseModel from pydantic import Field from typing_extensions import override -from ..utils._google_client_headers import merge_tracking_headers from .base_llm import BaseLlm from .llm_request import LlmRequest from .llm_response import LlmResponse @@ -1391,18 +1390,6 @@ Functions: """ -def _is_litellm_vertex_model(model_string: str) -> bool: - """Check if the model is a Vertex AI model accessed via LiteLLM. - - Args: - model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash") - - Returns: - True if it's a Vertex AI model accessed via LiteLLM, False otherwise - """ - return model_string.startswith("vertex_ai/") - - def _is_litellm_gemini_model(model_string: str) -> bool: """Check if the model is a Gemini model accessed via LiteLLM. @@ -1575,14 +1562,6 @@ class LiteLlm(BaseLlm): } completion_args.update(self._additional_args) - # merge headers - if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model( - effective_model - ): - completion_args["headers"] = merge_tracking_headers( - completion_args.get("headers") - ) - if generation_params: completion_args.update(generation_params) diff --git a/src/google/adk/utils/_google_client_headers.py b/src/google/adk/utils/_google_client_headers.py deleted file mode 100644 index 14408178..00000000 --- a/src/google/adk/utils/_google_client_headers.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from ._client_labels_utils import get_client_labels - - -def get_tracking_headers() -> dict[str, str]: - """Returns a dictionary of HTTP headers for tracking API requests. - - These headers are used to identify HTTP calls made by ADK towards - Vertex AI LLM APIs. - """ - labels = get_client_labels() - header_value = " ".join(labels) - return { - "x-goog-api-client": header_value, - "user-agent": header_value, - } - - -def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]: - """Merge tracking headers to the given headers. - - Args: - headers: headers to merge tracking headers into. - - Returns: - A dictionary of HTTP headers with tracking headers merged. - """ - new_headers = (headers or {}).copy() - for key, tracking_header_value in get_tracking_headers().items(): - custom_value = new_headers.get(key, None) - if not custom_value: - new_headers[key] = tracking_header_value - continue - - # Merge tracking headers with existing headers and avoid duplicates. - value_parts = tracking_header_value.split(" ") - for custom_value_part in custom_value.split(" "): - if custom_value_part not in value_parts: - value_parts.append(custom_value_part) - new_headers[key] = " ".join(value_parts) - return new_headers diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index e38cbc45..e1880abf 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -391,31 +391,6 @@ async def test_anthropic_llm_generate_content_async( assert responses[0].content.parts[0].text == "Hello, how can I help you?" -def test_claude_vertex_client_uses_tracking_headers(): - """Tests that Claude vertex client is called with tracking headers.""" - with mock.patch.object( - anthropic_llm, "AsyncAnthropicVertex", autospec=True - ) as mock_anthropic_vertex: - with mock.patch.dict( - os.environ, - { - "GOOGLE_CLOUD_PROJECT": "test-project", - "GOOGLE_CLOUD_LOCATION": "us-central1", - }, - ): - instance = Claude(model="claude-3-5-sonnet-v2@20241022") - _ = instance._anthropic_client - mock_anthropic_vertex.assert_called_once() - _, kwargs = mock_anthropic_vertex.call_args - assert "default_headers" in kwargs - assert "x-goog-api-client" in kwargs["default_headers"] - assert "user-agent" in kwargs["default_headers"] - assert ( - f"google-adk/{adk_version.__version__}" - in kwargs["default_headers"]["user-agent"] - ) - - @pytest.mark.asyncio async def test_generate_content_async_with_max_tokens( llm_request, generate_content_response, generate_llm_response diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 4f966d3d..c42260a2 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -31,7 +31,6 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG -from google.adk.utils._google_client_headers import get_tracking_headers from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.errors import ClientError @@ -470,7 +469,7 @@ async def test_generate_content_async_with_custom_headers( """Test that tracking headers are updated when custom headers are provided.""" # Add custom headers to the request config custom_headers = {"custom-header": "custom-value"} - tracking_headers = get_tracking_headers() + tracking_headers = gemini_llm._tracking_headers() for key in tracking_headers: custom_headers[key] = "custom " + tracking_headers[key] llm_request.config.http_options = types.HttpOptions(headers=custom_headers) @@ -495,7 +494,7 @@ async def test_generate_content_async_with_custom_headers( config_arg = call_args.kwargs["config"] for key, value in config_arg.http_options.headers.items(): - tracking_headers = get_tracking_headers() + tracking_headers = gemini_llm._tracking_headers() if key in tracking_headers: assert value == tracking_headers[key] + " custom" else: @@ -546,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers( config_arg = call_args.kwargs["config"] expected_headers = custom_headers.copy() - expected_headers.update(get_tracking_headers()) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers assert len(responses) == 2 @@ -600,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers( assert final_config.http_options is not None assert ( final_config.http_options.headers["x-goog-api-client"] - == get_tracking_headers()["x-goog-api-client"] + == gemini_llm._tracking_headers()["x-goog-api-client"] ) assert len(responses) == 2 if stream else 1 @@ -634,7 +633,7 @@ def test_live_api_client_properties(gemini_llm): assert http_options.api_version == "v1beta1" # Check that tracking headers are included - tracking_headers = get_tracking_headers() + tracking_headers = gemini_llm._tracking_headers() for key, value in tracking_headers.items(): assert key in http_options.headers assert value in http_options.headers[key] @@ -672,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request): # Verify that tracking headers were merged with custom headers expected_headers = custom_headers.copy() - expected_headers.update(get_tracking_headers()) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers # Verify that API version was set diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 7df83e14..4cf0329a 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2447,12 +2447,11 @@ def test_model_response_to_chunk( async def test_acompletion_additional_args(mock_acompletion, mock_client): lite_llm_instance = LiteLlm( # valid args - model="vertex_ai/test_model", + model="test_model", llm_client=mock_client, api_key="test_key", api_base="some://url", api_version="2024-09-12", - headers={"custom": "header"}, # Add custom header to test merge # invalid args (ignored) stream=True, messages=[{"role": "invalid", "content": "invalid"}], @@ -2479,43 +2478,13 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client): _, kwargs = mock_acompletion.call_args - assert kwargs["model"] == "vertex_ai/test_model" + assert kwargs["model"] == "test_model" assert kwargs["messages"][0]["role"] == "user" assert kwargs["messages"][0]["content"] == "Test prompt" assert kwargs["tools"][0]["function"]["name"] == "test_function" assert "stream" not in kwargs assert "llm_client" not in kwargs assert kwargs["api_base"] == "some://url" - assert "headers" in kwargs - assert kwargs["headers"]["custom"] == "header" - assert "x-goog-api-client" in kwargs["headers"] - assert "user-agent" in kwargs["headers"] - - -@pytest.mark.asyncio -async def test_acompletion_additional_args_non_vertex( - mock_acompletion, mock_client -): - """Test that tracking headers are not added for non-Vertex AI models.""" - lite_llm_instance = LiteLlm( - model="openai/gpt-4o", - llm_client=mock_client, - api_key="test_key", - headers={"custom": "header"}, - ) - - async for _ in lite_llm_instance.generate_content_async( - LLM_REQUEST_WITH_FUNCTION_DECLARATION - ): - pass - - mock_acompletion.assert_called_once() - _, kwargs = mock_acompletion.call_args - assert kwargs["model"] == "openai/gpt-4o" - assert "headers" in kwargs - assert kwargs["headers"]["custom"] == "header" - assert "x-goog-api-client" not in kwargs["headers"] - assert "user-agent" not in kwargs["headers"] @pytest.mark.asyncio diff --git a/tests/unittests/utils/test_google_client_headers.py b/tests/unittests/utils/test_google_client_headers.py deleted file mode 100644 index e7cc0296..00000000 --- a/tests/unittests/utils/test_google_client_headers.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -from google.adk import version -from google.adk.utils import _google_client_headers -import pytest - -_EXPECTED_BASE_HEADER = ( - f"google-adk/{version.__version__} gl-python/{sys.version.split()[0]}" -) - - -def test_get_tracking_headers(): - """Test get_tracking_headers returns correct headers.""" - headers = _google_client_headers.get_tracking_headers() - assert headers == { - "x-goog-api-client": _EXPECTED_BASE_HEADER, - "user-agent": _EXPECTED_BASE_HEADER, - } - - -@pytest.mark.parametrize( - "input_headers, expected_headers", - [ - ( - None, - { - "x-goog-api-client": _EXPECTED_BASE_HEADER, - "user-agent": _EXPECTED_BASE_HEADER, - }, - ), - ( - {}, - { - "x-goog-api-client": _EXPECTED_BASE_HEADER, - "user-agent": _EXPECTED_BASE_HEADER, - }, - ), - ( - {"x-goog-api-client": "label3 label4"}, - { - "x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3 label4", - "user-agent": _EXPECTED_BASE_HEADER, - }, - ), - ( - {"x-goog-api-client": f"gl-python/{sys.version.split()[0]} label3"}, - { - "x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3", - "user-agent": _EXPECTED_BASE_HEADER, - }, - ), - ( - {"other-header": "value"}, - { - "x-goog-api-client": _EXPECTED_BASE_HEADER, - "user-agent": _EXPECTED_BASE_HEADER, - "other-header": "value", - }, - ), - ], -) -def test_merge_tracking_headers(input_headers, expected_headers): - """Test merge_tracking_headers with various inputs.""" - headers = _google_client_headers.merge_tracking_headers(input_headers) - assert headers == expected_headers