You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Mark Vertex calls made from non-gemini models
PiperOrigin-RevId: 848159669
This commit is contained in:
committed by
Copybara-Service
parent
0f5b677c53
commit
871571d997
@@ -36,7 +36,6 @@ from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from ..utils._google_client_headers import get_tracking_headers
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
@@ -346,5 +345,4 @@ class Claude(AnthropicLlm):
|
||||
return AsyncAnthropicVertex(
|
||||
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
region=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
default_headers=get_tracking_headers(),
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ from google.adk import version as adk_version
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..utils._google_client_headers import merge_tracking_headers
|
||||
from ..utils.env_utils import is_env_enabled
|
||||
from .google_llm import Gemini
|
||||
|
||||
@@ -146,7 +145,7 @@ class ApigeeLlm(Gemini):
|
||||
kwargs_for_http_options['api_version'] = self._api_version
|
||||
http_options = types.HttpOptions(
|
||||
base_url=self._proxy_url,
|
||||
headers=merge_tracking_headers(self._custom_headers),
|
||||
headers=self._merge_tracking_headers(self._custom_headers),
|
||||
retry_options=self.retry_options,
|
||||
**kwargs_for_http_options,
|
||||
)
|
||||
|
||||
@@ -30,8 +30,7 @@ from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
from typing_extensions import override
|
||||
|
||||
from ..utils._google_client_headers import get_tracking_headers
|
||||
from ..utils._google_client_headers import merge_tracking_headers
|
||||
from ..utils._client_labels_utils import get_client_labels
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.streaming_utils import StreamingResponseAggregator
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
@@ -192,7 +191,7 @@ class Gemini(BaseLlm):
|
||||
if llm_request.config:
|
||||
if not llm_request.config.http_options:
|
||||
llm_request.config.http_options = types.HttpOptions()
|
||||
llm_request.config.http_options.headers = merge_tracking_headers(
|
||||
llm_request.config.http_options.headers = self._merge_tracking_headers(
|
||||
llm_request.config.http_options.headers
|
||||
)
|
||||
|
||||
@@ -303,7 +302,7 @@ class Gemini(BaseLlm):
|
||||
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=get_tracking_headers(),
|
||||
headers=self._tracking_headers(),
|
||||
retry_options=self.retry_options,
|
||||
)
|
||||
)
|
||||
@@ -316,6 +315,15 @@ class Gemini(BaseLlm):
|
||||
else GoogleLLMVariant.GEMINI_API
|
||||
)
|
||||
|
||||
def _tracking_headers(self) -> dict[str, str]:
|
||||
labels = get_client_labels()
|
||||
header_value = ' '.join(labels)
|
||||
tracking_headers = {
|
||||
'x-goog-api-client': header_value,
|
||||
'user-agent': header_value,
|
||||
}
|
||||
return tracking_headers
|
||||
|
||||
@cached_property
|
||||
def _live_api_version(self) -> str:
|
||||
if self._api_backend == GoogleLLMVariant.VERTEX_AI:
|
||||
@@ -331,7 +339,7 @@ class Gemini(BaseLlm):
|
||||
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=get_tracking_headers(), api_version=self._live_api_version
|
||||
headers=self._tracking_headers(), api_version=self._live_api_version
|
||||
)
|
||||
)
|
||||
|
||||
@@ -354,10 +362,8 @@ class Gemini(BaseLlm):
|
||||
):
|
||||
if not llm_request.live_connect_config.http_options.headers:
|
||||
llm_request.live_connect_config.http_options.headers = {}
|
||||
llm_request.live_connect_config.http_options.headers = (
|
||||
merge_tracking_headers(
|
||||
llm_request.live_connect_config.http_options.headers
|
||||
)
|
||||
llm_request.live_connect_config.http_options.headers.update(
|
||||
self._tracking_headers()
|
||||
)
|
||||
llm_request.live_connect_config.http_options.api_version = (
|
||||
self._live_api_version
|
||||
@@ -441,6 +447,23 @@ class Gemini(BaseLlm):
|
||||
llm_request.config.system_instruction = None
|
||||
await self._adapt_computer_use_tool(llm_request)
|
||||
|
||||
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Merge tracking headers to the given headers."""
|
||||
headers = headers or {}
|
||||
for key, tracking_header_value in self._tracking_headers().items():
|
||||
custom_value = headers.get(key, None)
|
||||
if not custom_value:
|
||||
headers[key] = tracking_header_value
|
||||
continue
|
||||
|
||||
# Merge tracking headers with existing headers and avoid duplicates.
|
||||
value_parts = tracking_header_value.split(' ')
|
||||
for custom_value_part in custom_value.split(' '):
|
||||
if custom_value_part not in value_parts:
|
||||
value_parts.append(custom_value_part)
|
||||
headers[key] = ' '.join(value_parts)
|
||||
return headers
|
||||
|
||||
|
||||
def _build_function_declaration_log(
|
||||
func_decl: types.FunctionDeclaration,
|
||||
|
||||
@@ -57,7 +57,6 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from ..utils._google_client_headers import merge_tracking_headers
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
@@ -1391,18 +1390,6 @@ Functions:
|
||||
"""
|
||||
|
||||
|
||||
def _is_litellm_vertex_model(model_string: str) -> bool:
|
||||
"""Check if the model is a Vertex AI model accessed via LiteLLM.
|
||||
|
||||
Args:
|
||||
model_string: A LiteLLM model string (e.g., "vertex_ai/gemini-2.5-flash")
|
||||
|
||||
Returns:
|
||||
True if it's a Vertex AI model accessed via LiteLLM, False otherwise
|
||||
"""
|
||||
return model_string.startswith("vertex_ai/")
|
||||
|
||||
|
||||
def _is_litellm_gemini_model(model_string: str) -> bool:
|
||||
"""Check if the model is a Gemini model accessed via LiteLLM.
|
||||
|
||||
@@ -1575,14 +1562,6 @@ class LiteLlm(BaseLlm):
|
||||
}
|
||||
completion_args.update(self._additional_args)
|
||||
|
||||
# merge headers
|
||||
if _is_litellm_vertex_model(effective_model) or _is_litellm_gemini_model(
|
||||
effective_model
|
||||
):
|
||||
completion_args["headers"] = merge_tracking_headers(
|
||||
completion_args.get("headers")
|
||||
)
|
||||
|
||||
if generation_params:
|
||||
completion_args.update(generation_params)
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._client_labels_utils import get_client_labels
|
||||
|
||||
|
||||
def get_tracking_headers() -> dict[str, str]:
|
||||
"""Returns a dictionary of HTTP headers for tracking API requests.
|
||||
|
||||
These headers are used to identify HTTP calls made by ADK towards
|
||||
Vertex AI LLM APIs.
|
||||
"""
|
||||
labels = get_client_labels()
|
||||
header_value = " ".join(labels)
|
||||
return {
|
||||
"x-goog-api-client": header_value,
|
||||
"user-agent": header_value,
|
||||
}
|
||||
|
||||
|
||||
def merge_tracking_headers(headers: dict[str, str] | None) -> dict[str, str]:
|
||||
"""Merge tracking headers to the given headers.
|
||||
|
||||
Args:
|
||||
headers: headers to merge tracking headers into.
|
||||
|
||||
Returns:
|
||||
A dictionary of HTTP headers with tracking headers merged.
|
||||
"""
|
||||
new_headers = (headers or {}).copy()
|
||||
for key, tracking_header_value in get_tracking_headers().items():
|
||||
custom_value = new_headers.get(key, None)
|
||||
if not custom_value:
|
||||
new_headers[key] = tracking_header_value
|
||||
continue
|
||||
|
||||
# Merge tracking headers with existing headers and avoid duplicates.
|
||||
value_parts = tracking_header_value.split(" ")
|
||||
for custom_value_part in custom_value.split(" "):
|
||||
if custom_value_part not in value_parts:
|
||||
value_parts.append(custom_value_part)
|
||||
new_headers[key] = " ".join(value_parts)
|
||||
return new_headers
|
||||
@@ -391,31 +391,6 @@ async def test_anthropic_llm_generate_content_async(
|
||||
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
|
||||
|
||||
|
||||
def test_claude_vertex_client_uses_tracking_headers():
|
||||
"""Tests that Claude vertex client is called with tracking headers."""
|
||||
with mock.patch.object(
|
||||
anthropic_llm, "AsyncAnthropicVertex", autospec=True
|
||||
) as mock_anthropic_vertex:
|
||||
with mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"GOOGLE_CLOUD_PROJECT": "test-project",
|
||||
"GOOGLE_CLOUD_LOCATION": "us-central1",
|
||||
},
|
||||
):
|
||||
instance = Claude(model="claude-3-5-sonnet-v2@20241022")
|
||||
_ = instance._anthropic_client
|
||||
mock_anthropic_vertex.assert_called_once()
|
||||
_, kwargs = mock_anthropic_vertex.call_args
|
||||
assert "default_headers" in kwargs
|
||||
assert "x-goog-api-client" in kwargs["default_headers"]
|
||||
assert "user-agent" in kwargs["default_headers"]
|
||||
assert (
|
||||
f"google-adk/{adk_version.__version__}"
|
||||
in kwargs["default_headers"]["user-agent"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_content_async_with_max_tokens(
|
||||
llm_request, generate_content_response, generate_llm_response
|
||||
|
||||
@@ -31,7 +31,6 @@ from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
|
||||
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
|
||||
from google.adk.utils._google_client_headers import get_tracking_headers
|
||||
from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
@@ -470,7 +469,7 @@ async def test_generate_content_async_with_custom_headers(
|
||||
"""Test that tracking headers are updated when custom headers are provided."""
|
||||
# Add custom headers to the request config
|
||||
custom_headers = {"custom-header": "custom-value"}
|
||||
tracking_headers = get_tracking_headers()
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
for key in tracking_headers:
|
||||
custom_headers[key] = "custom " + tracking_headers[key]
|
||||
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
|
||||
@@ -495,7 +494,7 @@ async def test_generate_content_async_with_custom_headers(
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
for key, value in config_arg.http_options.headers.items():
|
||||
tracking_headers = get_tracking_headers()
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
if key in tracking_headers:
|
||||
assert value == tracking_headers[key] + " custom"
|
||||
else:
|
||||
@@ -546,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers(
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(get_tracking_headers())
|
||||
expected_headers.update(gemini_llm._tracking_headers())
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
assert len(responses) == 2
|
||||
@@ -600,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers(
|
||||
assert final_config.http_options is not None
|
||||
assert (
|
||||
final_config.http_options.headers["x-goog-api-client"]
|
||||
== get_tracking_headers()["x-goog-api-client"]
|
||||
== gemini_llm._tracking_headers()["x-goog-api-client"]
|
||||
)
|
||||
|
||||
assert len(responses) == 2 if stream else 1
|
||||
@@ -634,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
|
||||
assert http_options.api_version == "v1beta1"
|
||||
|
||||
# Check that tracking headers are included
|
||||
tracking_headers = get_tracking_headers()
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
for key, value in tracking_headers.items():
|
||||
assert key in http_options.headers
|
||||
assert value in http_options.headers[key]
|
||||
@@ -672,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):
|
||||
|
||||
# Verify that tracking headers were merged with custom headers
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(get_tracking_headers())
|
||||
expected_headers.update(gemini_llm._tracking_headers())
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
# Verify that API version was set
|
||||
|
||||
@@ -2447,12 +2447,11 @@ def test_model_response_to_chunk(
|
||||
async def test_acompletion_additional_args(mock_acompletion, mock_client):
|
||||
lite_llm_instance = LiteLlm(
|
||||
# valid args
|
||||
model="vertex_ai/test_model",
|
||||
model="test_model",
|
||||
llm_client=mock_client,
|
||||
api_key="test_key",
|
||||
api_base="some://url",
|
||||
api_version="2024-09-12",
|
||||
headers={"custom": "header"}, # Add custom header to test merge
|
||||
# invalid args (ignored)
|
||||
stream=True,
|
||||
messages=[{"role": "invalid", "content": "invalid"}],
|
||||
@@ -2479,43 +2478,13 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
|
||||
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
|
||||
assert kwargs["model"] == "vertex_ai/test_model"
|
||||
assert kwargs["model"] == "test_model"
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
assert kwargs["messages"][0]["content"] == "Test prompt"
|
||||
assert kwargs["tools"][0]["function"]["name"] == "test_function"
|
||||
assert "stream" not in kwargs
|
||||
assert "llm_client" not in kwargs
|
||||
assert kwargs["api_base"] == "some://url"
|
||||
assert "headers" in kwargs
|
||||
assert kwargs["headers"]["custom"] == "header"
|
||||
assert "x-goog-api-client" in kwargs["headers"]
|
||||
assert "user-agent" in kwargs["headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_additional_args_non_vertex(
|
||||
mock_acompletion, mock_client
|
||||
):
|
||||
"""Test that tracking headers are not added for non-Vertex AI models."""
|
||||
lite_llm_instance = LiteLlm(
|
||||
model="openai/gpt-4o",
|
||||
llm_client=mock_client,
|
||||
api_key="test_key",
|
||||
headers={"custom": "header"},
|
||||
)
|
||||
|
||||
async for _ in lite_llm_instance.generate_content_async(
|
||||
LLM_REQUEST_WITH_FUNCTION_DECLARATION
|
||||
):
|
||||
pass
|
||||
|
||||
mock_acompletion.assert_called_once()
|
||||
_, kwargs = mock_acompletion.call_args
|
||||
assert kwargs["model"] == "openai/gpt-4o"
|
||||
assert "headers" in kwargs
|
||||
assert kwargs["headers"]["custom"] == "header"
|
||||
assert "x-goog-api-client" not in kwargs["headers"]
|
||||
assert "user-agent" not in kwargs["headers"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
from google.adk import version
|
||||
from google.adk.utils import _google_client_headers
|
||||
import pytest
|
||||
|
||||
_EXPECTED_BASE_HEADER = (
|
||||
f"google-adk/{version.__version__} gl-python/{sys.version.split()[0]}"
|
||||
)
|
||||
|
||||
|
||||
def test_get_tracking_headers():
|
||||
"""Test get_tracking_headers returns correct headers."""
|
||||
headers = _google_client_headers.get_tracking_headers()
|
||||
assert headers == {
|
||||
"x-goog-api-client": _EXPECTED_BASE_HEADER,
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_headers, expected_headers",
|
||||
[
|
||||
(
|
||||
None,
|
||||
{
|
||||
"x-goog-api-client": _EXPECTED_BASE_HEADER,
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
},
|
||||
),
|
||||
(
|
||||
{},
|
||||
{
|
||||
"x-goog-api-client": _EXPECTED_BASE_HEADER,
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
},
|
||||
),
|
||||
(
|
||||
{"x-goog-api-client": "label3 label4"},
|
||||
{
|
||||
"x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3 label4",
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
},
|
||||
),
|
||||
(
|
||||
{"x-goog-api-client": f"gl-python/{sys.version.split()[0]} label3"},
|
||||
{
|
||||
"x-goog-api-client": f"{_EXPECTED_BASE_HEADER} label3",
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
},
|
||||
),
|
||||
(
|
||||
{"other-header": "value"},
|
||||
{
|
||||
"x-goog-api-client": _EXPECTED_BASE_HEADER,
|
||||
"user-agent": _EXPECTED_BASE_HEADER,
|
||||
"other-header": "value",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_merge_tracking_headers(input_headers, expected_headers):
|
||||
"""Test merge_tracking_headers with various inputs."""
|
||||
headers = _google_client_headers.merge_tracking_headers(input_headers)
|
||||
assert headers == expected_headers
|
||||
Reference in New Issue
Block a user