chore: Adding Eval Client label to model calls made during evals

Co-authored-by: Ankur Sharma <ankusharma@google.com>
PiperOrigin-RevId: 838857867
This commit is contained in:
Ankur Sharma
2025-12-01 11:20:08 -08:00
committed by Copybara-Service
parent 8e82838f1e
commit 2a1a41d3ec
5 changed files with 194 additions and 53 deletions
+20 -16
View File
@@ -31,6 +31,8 @@ from ..errors.not_found_error import NotFoundError
from ..memory.base_memory_service import BaseMemoryService
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils._client_labels_utils import client_label_context
from ..utils._client_labels_utils import EVAL_CLIENT_LABEL
from ..utils.feature_decorator import experimental
from .base_eval_service import BaseEvalService
from .base_eval_service import EvaluateConfig
@@ -249,11 +251,12 @@ class LocalEvalService(BaseEvalService):
for eval_metric in evaluate_config.eval_metrics:
# Perform evaluation of the metric.
try:
evaluation_result = await self._evaluate_metric(
eval_metric=eval_metric,
actual_invocations=inference_result.inferences,
expected_invocations=eval_case.conversation,
)
with client_label_context(EVAL_CLIENT_LABEL):
evaluation_result = await self._evaluate_metric(
eval_metric=eval_metric,
actual_invocations=inference_result.inferences,
expected_invocations=eval_case.conversation,
)
except Exception as e:
# We intentionally catch the Exception as we don't want failures to
# affect other metric evaluation.
@@ -403,17 +406,18 @@ class LocalEvalService(BaseEvalService):
)
try:
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=root_agent,
user_simulator=self._user_simulator_provider.provide(eval_case),
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
memory_service=self._memory_service,
)
)
with client_label_context(EVAL_CLIENT_LABEL):
inferences = (
await EvaluationGenerator._generate_inferences_from_root_agent(
root_agent=root_agent,
user_simulator=self._user_simulator_provider.provide(eval_case),
initial_session=initial_session,
session_id=session_id,
session_service=self._session_service,
artifact_service=self._artifact_service,
memory_service=self._memory_service,
)
)
inference_result.inferences = inferences
inference_result.status = InferenceStatus.SUCCESS
+10 -17
View File
@@ -19,8 +19,6 @@ import contextlib
import copy
from functools import cached_property
import logging
import os
import sys
from typing import AsyncGenerator
from typing import cast
from typing import Optional
@@ -31,7 +29,7 @@ from google.genai import types
from google.genai.errors import ClientError
from typing_extensions import override
from .. import version
from ..utils._client_labels_utils import get_client_labels
from ..utils.context_utils import Aclosing
from ..utils.streaming_utils import StreamingResponseAggregator
from ..utils.variant_utils import GoogleLLMVariant
@@ -49,8 +47,7 @@ logger = logging.getLogger('google_adk.' + __name__)
_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine'
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'
_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """
On how to mitigate this issue, please refer to:
@@ -245,7 +242,7 @@ class Gemini(BaseLlm):
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers,
headers=self._tracking_headers(),
retry_options=self.retry_options,
)
)
@@ -258,16 +255,12 @@ class Gemini(BaseLlm):
else GoogleLLMVariant.GEMINI_API
)
@cached_property
def _tracking_headers(self) -> dict[str, str]:
framework_label = f'google-adk/{version.__version__}'
if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}'
language_label = 'gl-python/' + sys.version.split()[0]
version_header_value = f'{framework_label} {language_label}'
labels = get_client_labels()
header_value = ' '.join(labels)
tracking_headers = {
'x-goog-api-client': version_header_value,
'user-agent': version_header_value,
'x-goog-api-client': header_value,
'user-agent': header_value,
}
return tracking_headers
@@ -286,7 +279,7 @@ class Gemini(BaseLlm):
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers, api_version=self._live_api_version
headers=self._tracking_headers(), api_version=self._live_api_version
)
)
@@ -310,7 +303,7 @@ class Gemini(BaseLlm):
if not llm_request.live_connect_config.http_options.headers:
llm_request.live_connect_config.http_options.headers = {}
llm_request.live_connect_config.http_options.headers.update(
self._tracking_headers
self._tracking_headers()
)
llm_request.live_connect_config.http_options.api_version = (
self._live_api_version
@@ -397,7 +390,7 @@ class Gemini(BaseLlm):
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""Merge tracking headers to the given headers."""
headers = headers or {}
for key, tracking_header_value in self._tracking_headers.items():
for key, tracking_header_value in self._tracking_headers().items():
custom_value = headers.get(key, None)
if not custom_value:
headers[key] = tracking_header_value
@@ -0,0 +1,78 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from contextlib import contextmanager
import contextvars
import os
import sys
from typing import List
from .. import version
_ADK_LABEL = "google-adk"
_LANGUAGE_LABEL = "gl-python"
_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine"
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID"
EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}"
"""Label used to denote calls emerging to external system as a part of Evals."""
# The ContextVar holds client label collected for the current request.
_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar(
"_LABEL_CONTEXT", default=None
)
def _get_default_labels() -> List[str]:
"""Returns a list of labels that are always added."""
framework_label = f"{_ADK_LABEL}/{version.__version__}"
if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}"
language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0]
return [framework_label, language_label]
@contextmanager
def client_label_context(client_label: str):
"""Runs the operation within the context of the given client label."""
current_client_label = _LABEL_CONTEXT.get()
if current_client_label is not None:
raise ValueError(
"Client label already exists. You can only add one client label."
)
token = _LABEL_CONTEXT.set(client_label)
try:
yield
finally:
# Restore the previous state of the context variable
_LABEL_CONTEXT.reset(token)
def get_client_labels() -> List[str]:
"""Returns the current list of client labels that can be added to HTTP Headers."""
labels = _get_default_labels()
current_client_label = _LABEL_CONTEXT.get()
if current_client_label:
labels.append(current_client_label)
return labels
+18 -20
View File
@@ -22,8 +22,6 @@ from google.adk import version as adk_version
from google.adk.agents.context_cache_config import ContextCacheConfig
from google.adk.models.cache_metadata import CacheMetadata
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG
from google.adk.models.google_llm import _build_function_declaration_log
from google.adk.models.google_llm import _build_request_log
from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE
@@ -31,6 +29,8 @@ from google.adk.models.google_llm import _ResourceExhaustedError
from google.adk.models.google_llm import Gemini
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.errors import ClientError
@@ -142,13 +142,6 @@ def llm_request_with_computer_use():
)
@pytest.fixture
def mock_os_environ():
initial_env = os.environ.copy()
with mock.patch.dict(os.environ, initial_env, clear=False) as m:
yield m
def test_supported_models():
models = Gemini.supported_models()
assert len(models) == 4
@@ -193,12 +186,15 @@ def test_client_version_header():
)
def test_client_version_header_with_agent_engine(mock_os_environ):
os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project"
def test_client_version_header_with_agent_engine(monkeypatch):
monkeypatch.setenv(
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project"
)
model = Gemini(model="gemini-1.5-flash")
client = model.api_client
# Check that ADK version with telemetry tag and Python version are present in headers
# Check that ADK version with telemetry tag and Python version are present in
# headers
adk_version_with_telemetry = (
f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}"
)
@@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers(
"""Test that tracking headers are updated when custom headers are provided."""
# Add custom headers to the request config
custom_headers = {"custom-header": "custom-value"}
for key in gemini_llm._tracking_headers:
custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
tracking_headers = gemini_llm._tracking_headers()
for key in tracking_headers:
custom_headers[key] = "custom " + tracking_headers[key]
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
with mock.patch.object(gemini_llm, "api_client") as mock_client:
@@ -497,8 +494,9 @@ async def test_generate_content_async_with_custom_headers(
config_arg = call_args.kwargs["config"]
for key, value in config_arg.http_options.headers.items():
if key in gemini_llm._tracking_headers:
assert value == gemini_llm._tracking_headers[key] + " custom"
tracking_headers = gemini_llm._tracking_headers()
if key in tracking_headers:
assert value == tracking_headers[key] + " custom"
else:
assert value == custom_headers[key]
@@ -547,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers(
config_arg = call_args.kwargs["config"]
expected_headers = custom_headers.copy()
expected_headers.update(gemini_llm._tracking_headers)
expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
assert len(responses) == 2
@@ -601,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers(
assert final_config.http_options is not None
assert (
final_config.http_options.headers["x-goog-api-client"]
== gemini_llm._tracking_headers["x-goog-api-client"]
== gemini_llm._tracking_headers()["x-goog-api-client"]
)
assert len(responses) == 2 if stream else 1
@@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
assert http_options.api_version == "v1beta1"
# Check that tracking headers are included
tracking_headers = gemini_llm._tracking_headers
tracking_headers = gemini_llm._tracking_headers()
for key, value in tracking_headers.items():
assert key in http_options.headers
assert value in http_options.headers[key]
@@ -673,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):
# Verify that tracking headers were merged with custom headers
expected_headers = custom_headers.copy()
expected_headers.update(gemini_llm._tracking_headers)
expected_headers.update(gemini_llm._tracking_headers())
assert config_arg.http_options.headers == expected_headers
# Verify that API version was set
@@ -0,0 +1,68 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from google.adk import version
from google.adk.utils import _client_labels_utils
import pytest
def test_get_client_labels_default():
"""Test get_client_labels returns default labels."""
labels = _client_labels_utils.get_client_labels()
assert len(labels) == 2
assert f"google-adk/{version.__version__}" == labels[0]
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
def test_get_client_labels_with_agent_engine_id(monkeypatch):
"""Test get_client_labels returns agent engine tag when env var is set."""
monkeypatch.setenv(
_client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
"test-agent-id",
)
labels = _client_labels_utils.get_client_labels()
assert len(labels) == 2
assert (
f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}"
== labels[0]
)
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
def test_get_client_labels_with_context():
"""Test get_client_labels includes label from context."""
with _client_labels_utils.client_label_context("my-label/1.0"):
labels = _client_labels_utils.get_client_labels()
assert len(labels) == 3
assert f"google-adk/{version.__version__}" == labels[0]
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
assert "my-label/1.0" == labels[2]
def test_client_label_context_nested_error():
"""Test client_label_context raises error when nested."""
with pytest.raises(ValueError, match="Client label already exists"):
with _client_labels_utils.client_label_context("my-label/1.0"):
with _client_labels_utils.client_label_context("another-label/1.0"):
pass
def test_eval_client_label():
"""Test EVAL_CLIENT_LABEL has correct format."""
assert (
f"google-adk-eval/{version.__version__}"
== _client_labels_utils.EVAL_CLIENT_LABEL
)