From 2a1a41d3ec60376aba14e5a0aa069e645dc121e1 Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Mon, 1 Dec 2025 11:20:08 -0800 Subject: [PATCH] chore: Adding Eval Client label to model calls made during evals Co-authored-by: Ankur Sharma PiperOrigin-RevId: 838857867 --- .../adk/evaluation/local_eval_service.py | 36 +++++---- src/google/adk/models/google_llm.py | 27 +++---- src/google/adk/utils/_client_labels_utils.py | 78 +++++++++++++++++++ tests/unittests/models/test_google_llm.py | 38 +++++---- .../utils/test_client_labels_utils.py | 68 ++++++++++++++++ 5 files changed, 194 insertions(+), 53 deletions(-) create mode 100644 src/google/adk/utils/_client_labels_utils.py create mode 100644 tests/unittests/utils/test_client_labels_utils.py diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 806a8d69..30344702 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -31,6 +31,8 @@ from ..errors.not_found_error import NotFoundError from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils._client_labels_utils import client_label_context +from ..utils._client_labels_utils import EVAL_CLIENT_LABEL from ..utils.feature_decorator import experimental from .base_eval_service import BaseEvalService from .base_eval_service import EvaluateConfig @@ -249,11 +251,12 @@ class LocalEvalService(BaseEvalService): for eval_metric in evaluate_config.eval_metrics: # Perform evaluation of the metric. try: - evaluation_result = await self._evaluate_metric( - eval_metric=eval_metric, - actual_invocations=inference_result.inferences, - expected_invocations=eval_case.conversation, - ) + with client_label_context(EVAL_CLIENT_LABEL): + evaluation_result = await self._evaluate_metric( + eval_metric=eval_metric, + actual_invocations=inference_result.inferences, + expected_invocations=eval_case.conversation, + ) except Exception as e: # We intentionally catch the Exception as we don't want failures to # affect other metric evaluation. @@ -403,17 +406,18 @@ class LocalEvalService(BaseEvalService): ) try: - inferences = ( - await EvaluationGenerator._generate_inferences_from_root_agent( - root_agent=root_agent, - user_simulator=self._user_simulator_provider.provide(eval_case), - initial_session=initial_session, - session_id=session_id, - session_service=self._session_service, - artifact_service=self._artifact_service, - memory_service=self._memory_service, - ) - ) + with client_label_context(EVAL_CLIENT_LABEL): + inferences = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide(eval_case), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + ) + ) inference_result.inferences = inferences inference_result.status = InferenceStatus.SUCCESS diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 90c2fece..93d802ec 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -19,8 +19,6 @@ import contextlib import copy from functools import cached_property import logging -import os -import sys from typing import AsyncGenerator from typing import cast from typing import Optional @@ -31,7 +29,7 @@ from google.genai import types from google.genai.errors import ClientError from typing_extensions import override -from .. import version +from ..utils._client_labels_utils import get_client_labels from ..utils.context_utils import Aclosing from ..utils.streaming_utils import StreamingResponseAggregator from ..utils.variant_utils import GoogleLLMVariant @@ -49,8 +47,7 @@ logger = logging.getLogger('google_adk.' + __name__) _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} -_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' -_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID' + _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ On how to mitigate this issue, please refer to: @@ -245,7 +242,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, + headers=self._tracking_headers(), retry_options=self.retry_options, ) ) @@ -258,16 +255,12 @@ class Gemini(BaseLlm): else GoogleLLMVariant.GEMINI_API ) - @cached_property def _tracking_headers(self) -> dict[str, str]: - framework_label = f'google-adk/{version.__version__}' - if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): - framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}' - language_label = 'gl-python/' + sys.version.split()[0] - version_header_value = f'{framework_label} {language_label}' + labels = get_client_labels() + header_value = ' '.join(labels) tracking_headers = { - 'x-goog-api-client': version_header_value, - 'user-agent': version_header_value, + 'x-goog-api-client': header_value, + 'user-agent': header_value, } return tracking_headers @@ -286,7 +279,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=self._live_api_version + headers=self._tracking_headers(), api_version=self._live_api_version ) ) @@ -310,7 +303,7 @@ class Gemini(BaseLlm): if not llm_request.live_connect_config.http_options.headers: llm_request.live_connect_config.http_options.headers = {} llm_request.live_connect_config.http_options.headers.update( - self._tracking_headers + self._tracking_headers() ) llm_request.live_connect_config.http_options.api_version = ( self._live_api_version @@ -397,7 +390,7 @@ class Gemini(BaseLlm): def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: """Merge tracking headers to the given headers.""" headers = headers or {} - for key, tracking_header_value in self._tracking_headers.items(): + for key, tracking_header_value in self._tracking_headers().items(): custom_value = headers.get(key, None) if not custom_value: headers[key] = tracking_header_value diff --git a/src/google/adk/utils/_client_labels_utils.py b/src/google/adk/utils/_client_labels_utils.py new file mode 100644 index 00000000..72858c3c --- /dev/null +++ b/src/google/adk/utils/_client_labels_utils.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +import contextvars +import os +import sys +from typing import List + +from .. import version + +_ADK_LABEL = "google-adk" +_LANGUAGE_LABEL = "gl-python" +_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine" +_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID" + + +EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}" +"""Label used to denote calls emerging to external system as a part of Evals.""" + +# The ContextVar holds client label collected for the current request. +_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar( + "_LABEL_CONTEXT", default=None +) + + +def _get_default_labels() -> List[str]: + """Returns a list of labels that are always added.""" + framework_label = f"{_ADK_LABEL}/{version.__version__}" + + if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): + framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}" + + language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0] + return [framework_label, language_label] + + +@contextmanager +def client_label_context(client_label: str): + """Runs the operation within the context of the given client label.""" + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label is not None: + raise ValueError( + "Client label already exists. You can only add one client label." + ) + + token = _LABEL_CONTEXT.set(client_label) + + try: + yield + finally: + # Restore the previous state of the context variable + _LABEL_CONTEXT.reset(token) + + +def get_client_labels() -> List[str]: + """Returns the current list of client labels that can be added to HTTP Headers.""" + labels = _get_default_labels() + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label: + labels.append(current_client_label) + + return labels diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index f2419daf..ddf1b076 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -22,8 +22,6 @@ from google.adk import version as adk_version from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.models.cache_metadata import CacheMetadata from google.adk.models.gemini_llm_connection import GeminiLlmConnection -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import _build_function_declaration_log from google.adk.models.google_llm import _build_request_log from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE @@ -31,6 +29,8 @@ from google.adk.models.google_llm import _ResourceExhaustedError from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.errors import ClientError @@ -142,13 +142,6 @@ def llm_request_with_computer_use(): ) -@pytest.fixture -def mock_os_environ(): - initial_env = os.environ.copy() - with mock.patch.dict(os.environ, initial_env, clear=False) as m: - yield m - - def test_supported_models(): models = Gemini.supported_models() assert len(models) == 4 @@ -193,12 +186,15 @@ def test_client_version_header(): ) -def test_client_version_header_with_agent_engine(mock_os_environ): - os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project" +def test_client_version_header_with_agent_engine(monkeypatch): + monkeypatch.setenv( + _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project" + ) model = Gemini(model="gemini-1.5-flash") client = model.api_client - # Check that ADK version with telemetry tag and Python version are present in headers + # Check that ADK version with telemetry tag and Python version are present in + # headers adk_version_with_telemetry = ( f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}" ) @@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers( """Test that tracking headers are updated when custom headers are provided.""" # Add custom headers to the request config custom_headers = {"custom-header": "custom-value"} - for key in gemini_llm._tracking_headers: - custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + tracking_headers = gemini_llm._tracking_headers() + for key in tracking_headers: + custom_headers[key] = "custom " + tracking_headers[key] llm_request.config.http_options = types.HttpOptions(headers=custom_headers) with mock.patch.object(gemini_llm, "api_client") as mock_client: @@ -497,8 +494,9 @@ async def test_generate_content_async_with_custom_headers( config_arg = call_args.kwargs["config"] for key, value in config_arg.http_options.headers.items(): - if key in gemini_llm._tracking_headers: - assert value == gemini_llm._tracking_headers[key] + " custom" + tracking_headers = gemini_llm._tracking_headers() + if key in tracking_headers: + assert value == tracking_headers[key] + " custom" else: assert value == custom_headers[key] @@ -547,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers( config_arg = call_args.kwargs["config"] expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers assert len(responses) == 2 @@ -601,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers( assert final_config.http_options is not None assert ( final_config.http_options.headers["x-goog-api-client"] - == gemini_llm._tracking_headers["x-goog-api-client"] + == gemini_llm._tracking_headers()["x-goog-api-client"] ) assert len(responses) == 2 if stream else 1 @@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm): assert http_options.api_version == "v1beta1" # Check that tracking headers are included - tracking_headers = gemini_llm._tracking_headers + tracking_headers = gemini_llm._tracking_headers() for key, value in tracking_headers.items(): assert key in http_options.headers assert value in http_options.headers[key] @@ -673,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request): # Verify that tracking headers were merged with custom headers expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers # Verify that API version was set diff --git a/tests/unittests/utils/test_client_labels_utils.py b/tests/unittests/utils/test_client_labels_utils.py new file mode 100644 index 00000000..b1d6acb0 --- /dev/null +++ b/tests/unittests/utils/test_client_labels_utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from google.adk import version +from google.adk.utils import _client_labels_utils +import pytest + + +def test_get_client_labels_default(): + """Test get_client_labels returns default labels.""" + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_agent_engine_id(monkeypatch): + """Test get_client_labels returns agent engine tag when env var is set.""" + monkeypatch.setenv( + _client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, + "test-agent-id", + ) + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert ( + f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}" + == labels[0] + ) + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_context(): + """Test get_client_labels includes label from context.""" + with _client_labels_utils.client_label_context("my-label/1.0"): + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 3 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + assert "my-label/1.0" == labels[2] + + +def test_client_label_context_nested_error(): + """Test client_label_context raises error when nested.""" + with pytest.raises(ValueError, match="Client label already exists"): + with _client_labels_utils.client_label_context("my-label/1.0"): + with _client_labels_utils.client_label_context("another-label/1.0"): + pass + + +def test_eval_client_label(): + """Test EVAL_CLIENT_LABEL has correct format.""" + assert ( + f"google-adk-eval/{version.__version__}" + == _client_labels_utils.EVAL_CLIENT_LABEL + )