feat: Mark Vertex calls made from non-gemini models

PiperOrigin-RevId: 848159669
This commit is contained in:
Google Team Member
2025-12-23 06:50:06 -08:00
committed by Copybara-Service
parent 0f5b677c53
commit 871571d997
9 changed files with 41 additions and 234 deletions
-2
View File
@@ -36,7 +36,6 @@ from google.genai import types
from pydantic import BaseModel
from typing_extensions import override
from ..utils._google_client_headers import get_tracking_headers
from .base_llm import BaseLlm
from .llm_response import LlmResponse
@@ -346,5 +345,4 @@ class Claude(AnthropicLlm):
return AsyncAnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
default_headers=get_tracking_headers(),
)
+1 -2
View File
@@ -25,7 +25,6 @@ from google.adk import version as adk_version
from google.genai import types
from typing_extensions import override
from ..utils._google_client_headers import merge_tracking_headers
from ..utils.env_utils import is_env_enabled
from .google_llm import Gemini
@@ -146,7 +145,7 @@ class ApigeeLlm(Gemini):
kwargs_for_http_options['api_version'] = self._api_version
http_options = types.HttpOptions(
base_url=self._proxy_url,
headers=merge_tracking_headers(self._custom_headers),
headers=self._merge_tracking_headers(self._custom_headers),
retry_options=self.retry_options,
**kwargs_for_http_options,
)
+32 -9
View File
@@ -30,8 +30,7 @@ from google.genai import types
from google.genai.errors import ClientError
from typing_extensions import override
from ..utils._google_client_headers import get_tracking_headers
from ..utils._google_client_headers import merge_tracking_headers
from ..utils._client_labels_utils import get_client_labels
from ..utils.context_utils import Aclosing
from ..utils.streaming_utils import StreamingResponseAggregator
from ..utils.variant_utils import GoogleLLMVariant
@@ -192,7 +191,7 @@ class Gemini(BaseLlm):
if llm_request.config:
if not llm_request.config.http_options:
llm_request.config.http_options = types.HttpOptions()
llm_request.config.http_options.headers = merge_tracking_headers(
llm_request.config.http_options.headers = self._merge_tracking_headers(
llm_request.config.http_options.headers
)
@@ -303,7 +302,7 @@ class Gemini(BaseLlm):
return Client(
http_options=types.HttpOptions(
headers=get_tracking_headers(),
headers=self._tracking_headers(),
retry_options=self.retry_options,
)
)
@@ -316,6 +315,15 @@ class Gemini(BaseLlm):
else GoogleLLMVariant.GEMINI_API
)
def _tracking_headers(self) -> dict[str, str]:
  """Build the HTTP headers that tag outgoing requests with client labels.

  The labels returned by `get_client_labels()` are joined with single spaces
  and the resulting value is used for both header names.

  Returns:
    A new dict with the `x-goog-api-client` and `user-agent` headers.
  """
  joined_labels = ' '.join(get_client_labels())
  return {
      'x-goog-api-client': joined_labels,
      'user-agent': joined_labels,
  }
@cached_property
def _live_api_version(self) -> str:
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
@@ -331,7 +339,7 @@ class Gemini(BaseLlm):
return Client(
http_options=types.HttpOptions(
headers=get_tracking_headers(), api_version=self._live_api_version
headers=self._tracking_headers(), api_version=self._live_api_version
)
)
@@ -354,10 +362,8 @@ class Gemini(BaseLlm):
):
if not llm_request.live_connect_config.http_options.headers:
llm_request.live_connect_config.http_options.headers = {}
llm_request.live_connect_config.http_options.headers = (
merge_tracking_headers(
llm_request.live_connect_config.http_options.headers
)
llm_request.live_connect_config.http_options.headers.update(
self._tracking_headers()
)
llm_request.live_connect_config.http_options.api_version = (
self._live_api_version
@@ -441,6 +447,23 @@ class Gemini(BaseLlm):
llm_request.config.system_instruction = None
await self._adapt_computer_use_tool(llm_request)
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""Merge tracking headers to the given headers."""
headers = headers or {}
for key, tracking_header_value in self._tracking_headers().items():
custom_value = headers.get(key, None)
if not custom_value:
headers[key] = tracking_header_value
continue
# Merge tracking headers with existing headers and avoid duplicates.
value_parts = tracking_header_value.split(' ')
for custom_value_part in custom_value.split(' '):
if custom_value_part not in value_parts:
value_parts.append(custom_value_part)
headers[key] = ' '.join(value_parts)
return headers
def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
-21
View File
@@ -57,7 +57,6 @@ from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from ..utils._google_client_headers import merge_tracking_headers
from .base_llm import BaseLlm
from .llm_request import LlmRequest
from .llm_response import LlmResponse
@@ -1391,18 +1390,6 @@ Functions:
"""
def _is_litellm_vertex_model(model_string: str) -> bool:
"""Check if the model is a Vertex AI model accessed via LiteLLM.
Args:
model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
Returns:
True if it's a Vertex AI model accessed via LiteLLM, False otherwise
"""
return model_string.startswith("vertex_ai/")
def _is_litellm_gemini_model(model_string: str) -> bool:
"""Check if the model is a Gemini model accessed via LiteLLM.
@@ -1575,14 +1562,6 @@ class LiteLlm(BaseLlm):
}
completion_args.update(self._additional_args)
# merge headers
if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
effective_model
):
completion_args["headers"] = merge_tracking_headers(
completion_args.get("headers")
)
if generation_params:
completion_args.update(generation_params)
@@ -1,56 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._client_labels_utils import get_client_labels
def get_tracking_headers() -> dict[str, str]:
  """Return the HTTP headers used for tracking API requests.

  These headers are used to identify HTTP calls made by ADK towards
  Vertex AI LLM APIs. The client labels are joined with single spaces and the
  same value is placed under both header names.
  """
  header_value = " ".join(get_client_labels())
  return dict.fromkeys(("x-goog-api-client", "user-agent"), header_value)
def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]:
  """Return a copy of `headers` with the tracking headers merged in.

  Args:
    headers: Headers to merge tracking headers into; `None` or an empty dict
      means there are no custom headers.

  Returns:
    A new dictionary of HTTP headers. For each tracking header the tracking
    value comes first, followed by any space-separated tokens of the custom
    value that were not already present.
  """
  merged = headers.copy() if headers else {}
  for name, tracking_value in get_tracking_headers().items():
    existing = merged.get(name)
    if existing:
      # Keep the tracking tokens first and append only the unseen custom ones.
      tokens = tracking_value.split(" ")
      for token in existing.split(" "):
        if token not in tokens:
          tokens.append(token)
      merged[name] = " ".join(tokens)
    else:
      merged[name] = tracking_value
  return merged
@@ -391,31 +391,6 @@ async def test_anthropic_llm_generate_content_async(
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
def test_claude_vertex_client_uses_tracking_headers():
  """Verifies the Claude Vertex client is constructed with tracking headers."""
  vertex_env = {
      "GOOGLE_CLOUD_PROJECT": "test-project",
      "GOOGLE_CLOUD_LOCATION": "us-central1",
  }
  with mock.patch.object(
      anthropic_llm, "AsyncAnthropicVertex", autospec=True
  ) as mock_anthropic_vertex, mock.patch.dict(os.environ, vertex_env):
    instance = Claude(model="claude-3-5-sonnet-v2@20241022")
    # Reading the attribute is what triggers construction of the client.
    _ = instance._anthropic_client

  mock_anthropic_vertex.assert_called_once()
  call_kwargs = mock_anthropic_vertex.call_args.kwargs
  assert "default_headers" in call_kwargs
  default_headers = call_kwargs["default_headers"]
  for header_name in ("x-goog-api-client", "user-agent"):
    assert header_name in default_headers
  assert (
      f"google-adk/{adk_version.__version__}" in default_headers["user-agent"]
  )
@pytest.mark.asyncio
async def test_generate_content_async_with_max_tokens(
llm_request, generate_content_response, generate_llm_response
+6 -7
View File
@@ -31,7 +31,6 @@ from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
from google.adk.utils._google_client_headers import get_tracking_headers
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.errors import ClientError
@@ -470,7 +469,7 @@ async def test_generate_content_async_with_custom_headers(
"""Test that tracking headers are updated when custom headers are provided."""
# Add custom headers to the request config
custom_headers = {"custom-header": "custom-value"}
tracking_headers = get_tracking_headers()
tracking_headers = gemini_llm._tracking_headers()
for key in tracking_headers:
custom_headers[key] = "custom " + tracking_headers[key]
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
@@ -495,7 +494,7 @@ async def test_generate_content_async_with_custom_headers(
config_arg = call_args.kwargs["config"]
for key, value in config_arg.http_options.headers.items():
tracking_headers = get_tracking_headers()
tracking_headers = gemini_llm._tracking_headers()
if key in tracking_headers:
assert value == tracking_headers[key] + " custom"
else:
@@ -546,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers(
config_arg = call_args.kwargs["config"]
expected_headers = custom_headers.copy()
expected_headers.update(get_tracking_headers())
expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
assert len(responses) == 2
@@ -600,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers(
assert final_config.http_options is not None
assert (
final_config.http_options.headers["x-goog-api-client"]
== get_tracking_headers()["x-goog-api-client"]
== gemini_llm._tracking_headers()["x-goog-api-client"]
)
assert len(responses) == 2 if stream else 1
@@ -634,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
assert http_options.api_version == "v1beta1"
# Check that tracking headers are included
tracking_headers = get_tracking_headers()
tracking_headers = gemini_llm._tracking_headers()
for key, value in tracking_headers.items():
assert key in http_options.headers
assert value in http_options.headers[key]
@@ -672,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):
# Verify that tracking headers were merged with custom headers
expected_headers = custom_headers.copy()
expected_headers.update(get_tracking_headers())
expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
# Verify that API version was set
+2 -33
View File
@@ -2447,12 +2447,11 @@ def test_model_response_to_chunk(
async def test_acompletion_additional_args(mock_acompletion, mock_client):
lite_llm_instance = LiteLlm(
# valid args
model="vertex_ai/test_model",
model="test_model",
llm_client=mock_client,
api_key="test_key",
api_base="some://url",
api_version="2024-09-12",
headers={"custom": "header"}, # Add custom header to test merge
# invalid args (ignored)
stream=True,
messages=[{"role": "invalid", "content": "invalid"}],
@@ -2479,43 +2478,13 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "vertex_ai/test_model"
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert "stream" not in kwargs
assert "llm_client" not in kwargs
assert kwargs["api_base"] == "some://url"
assert "headers" in kwargs
assert kwargs["headers"]["custom"] == "header"
assert "x-goog-api-client" in kwargs["headers"]
assert "user-agent" in kwargs["headers"]
@pytest.mark.asyncio
async def test_acompletion_additional_args_non_vertex(
    mock_acompletion, mock_client
):
  """Checks that non-Vertex AI models do not receive tracking headers."""
  lite_llm_instance = LiteLlm(
      model="openai/gpt-4o",
      llm_client=mock_client,
      api_key="test_key",
      headers={"custom": "header"},
  )

  # Drain the response stream; only the recorded completion call matters.
  response_stream = lite_llm_instance.generate_content_async(
      LLM_REQUEST_WITH_FUNCTION_DECLARATION
  )
  async for _ in response_stream:
    pass

  mock_acompletion.assert_called_once()
  sent_kwargs = mock_acompletion.call_args.kwargs
  assert sent_kwargs["model"] == "openai/gpt-4o"
  assert "headers" in sent_kwargs
  sent_headers = sent_kwargs["headers"]
  assert sent_headers["custom"] == "header"
  for tracking_name in ("x-goog-api-client", "user-agent"):
    assert tracking_name not in sent_headers
@pytest.mark.asyncio
@@ -1,79 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from google.adk import version
from google.adk.utils import _google_client_headers
import pytest
_EXPECTED_BASE_HEADER = (
f"google-adk/{version.__version__} gl-python/{sys.version.split()[0]}"
)
def test_get_tracking_headers():
  """Checks that both tracking headers carry the expected base value."""
  expected = dict.fromkeys(
      ("x-goog-api-client", "user-agent"), _EXPECTED_BASE_HEADER
  )
  assert _google_client_headers.get_tracking_headers() == expected
@pytest.mark.parametrize(
    "input_headers, expected_headers",
    [
        # No headers at all: only the tracking headers are produced.
        pytest.param(
            None,
            {
                "x-goog-api-client": _EXPECTED_BASE_HEADER,
                "user-agent": _EXPECTED_BASE_HEADER,
            },
        ),
        # Empty headers behave like no headers.
        pytest.param(
            {},
            {
                "x-goog-api-client": _EXPECTED_BASE_HEADER,
                "user-agent": _EXPECTED_BASE_HEADER,
            },
        ),
        # Custom labels are appended after the tracking labels.
        pytest.param(
            {"x-goog-api-client": "label3 label4"},
            {
                "x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3 label4",
                "user-agent": _EXPECTED_BASE_HEADER,
            },
        ),
        # A custom label that duplicates a tracking label is not repeated.
        pytest.param(
            {"x-goog-api-client": f"gl-python/{sys.version.split()[0]} label3"},
            {
                "x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3",
                "user-agent": _EXPECTED_BASE_HEADER,
            },
        ),
        # Unrelated headers pass through untouched.
        pytest.param(
            {"other-header": "value"},
            {
                "x-goog-api-client": _EXPECTED_BASE_HEADER,
                "user-agent": _EXPECTED_BASE_HEADER,
                "other-header": "value",
            },
        ),
    ],
)
def test_merge_tracking_headers(input_headers, expected_headers):
  """Checks merge_tracking_headers against each parametrized input."""
  merged = _google_client_headers.merge_tracking_headers(input_headers)
  assert merged == expected_headers