You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Adding Eval Client label to model calls made during evals
Co-authored-by: Ankur Sharma <ankusharma@google.com> PiperOrigin-RevId: 838857867
This commit is contained in:
committed by
Copybara-Service
parent
8e82838f1e
commit
2a1a41d3ec
@@ -31,6 +31,8 @@ from ..errors.not_found_error import NotFoundError
|
||||
from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..utils._client_labels_utils import client_label_context
|
||||
from ..utils._client_labels_utils import EVAL_CLIENT_LABEL
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .base_eval_service import BaseEvalService
|
||||
from .base_eval_service import EvaluateConfig
|
||||
@@ -249,11 +251,12 @@ class LocalEvalService(BaseEvalService):
|
||||
for eval_metric in evaluate_config.eval_metrics:
|
||||
# Perform evaluation of the metric.
|
||||
try:
|
||||
evaluation_result = await self._evaluate_metric(
|
||||
eval_metric=eval_metric,
|
||||
actual_invocations=inference_result.inferences,
|
||||
expected_invocations=eval_case.conversation,
|
||||
)
|
||||
with client_label_context(EVAL_CLIENT_LABEL):
|
||||
evaluation_result = await self._evaluate_metric(
|
||||
eval_metric=eval_metric,
|
||||
actual_invocations=inference_result.inferences,
|
||||
expected_invocations=eval_case.conversation,
|
||||
)
|
||||
except Exception as e:
|
||||
# We intentionally catch the Exception as we don't want failures to
|
||||
# affect other metric evaluation.
|
||||
@@ -403,17 +406,18 @@ class LocalEvalService(BaseEvalService):
|
||||
)
|
||||
|
||||
try:
|
||||
inferences = (
|
||||
await EvaluationGenerator._generate_inferences_from_root_agent(
|
||||
root_agent=root_agent,
|
||||
user_simulator=self._user_simulator_provider.provide(eval_case),
|
||||
initial_session=initial_session,
|
||||
session_id=session_id,
|
||||
session_service=self._session_service,
|
||||
artifact_service=self._artifact_service,
|
||||
memory_service=self._memory_service,
|
||||
)
|
||||
)
|
||||
with client_label_context(EVAL_CLIENT_LABEL):
|
||||
inferences = (
|
||||
await EvaluationGenerator._generate_inferences_from_root_agent(
|
||||
root_agent=root_agent,
|
||||
user_simulator=self._user_simulator_provider.provide(eval_case),
|
||||
initial_session=initial_session,
|
||||
session_id=session_id,
|
||||
session_service=self._session_service,
|
||||
artifact_service=self._artifact_service,
|
||||
memory_service=self._memory_service,
|
||||
)
|
||||
)
|
||||
|
||||
inference_result.inferences = inferences
|
||||
inference_result.status = InferenceStatus.SUCCESS
|
||||
|
||||
@@ -19,8 +19,6 @@ import contextlib
|
||||
import copy
|
||||
from functools import cached_property
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -31,7 +29,7 @@ from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
from ..utils._client_labels_utils import get_client_labels
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.streaming_utils import StreamingResponseAggregator
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
@@ -49,8 +47,7 @@ logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
|
||||
_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine'
|
||||
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID'
|
||||
|
||||
|
||||
_RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """
|
||||
On how to mitigate this issue, please refer to:
|
||||
@@ -245,7 +242,7 @@ class Gemini(BaseLlm):
|
||||
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers,
|
||||
headers=self._tracking_headers(),
|
||||
retry_options=self.retry_options,
|
||||
)
|
||||
)
|
||||
@@ -258,16 +255,12 @@ class Gemini(BaseLlm):
|
||||
else GoogleLLMVariant.GEMINI_API
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _tracking_headers(self) -> dict[str, str]:
|
||||
framework_label = f'google-adk/{version.__version__}'
|
||||
if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
|
||||
framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}'
|
||||
language_label = 'gl-python/' + sys.version.split()[0]
|
||||
version_header_value = f'{framework_label} {language_label}'
|
||||
labels = get_client_labels()
|
||||
header_value = ' '.join(labels)
|
||||
tracking_headers = {
|
||||
'x-goog-api-client': version_header_value,
|
||||
'user-agent': version_header_value,
|
||||
'x-goog-api-client': header_value,
|
||||
'user-agent': header_value,
|
||||
}
|
||||
return tracking_headers
|
||||
|
||||
@@ -286,7 +279,7 @@ class Gemini(BaseLlm):
|
||||
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=self._live_api_version
|
||||
headers=self._tracking_headers(), api_version=self._live_api_version
|
||||
)
|
||||
)
|
||||
|
||||
@@ -310,7 +303,7 @@ class Gemini(BaseLlm):
|
||||
if not llm_request.live_connect_config.http_options.headers:
|
||||
llm_request.live_connect_config.http_options.headers = {}
|
||||
llm_request.live_connect_config.http_options.headers.update(
|
||||
self._tracking_headers
|
||||
self._tracking_headers()
|
||||
)
|
||||
llm_request.live_connect_config.http_options.api_version = (
|
||||
self._live_api_version
|
||||
@@ -397,7 +390,7 @@ class Gemini(BaseLlm):
|
||||
def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Merge tracking headers to the given headers."""
|
||||
headers = headers or {}
|
||||
for key, tracking_header_value in self._tracking_headers.items():
|
||||
for key, tracking_header_value in self._tracking_headers().items():
|
||||
custom_value = headers.get(key, None)
|
||||
if not custom_value:
|
||||
headers[key] = tracking_header_value
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import contextvars
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from .. import version
|
||||
|
||||
_ADK_LABEL = "google-adk"
|
||||
_LANGUAGE_LABEL = "gl-python"
|
||||
_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine"
|
||||
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID"
|
||||
|
||||
|
||||
EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}"
|
||||
"""Label used to denote calls emerging to external system as a part of Evals."""
|
||||
|
||||
# The ContextVar holds client label collected for the current request.
|
||||
_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"_LABEL_CONTEXT", default=None
|
||||
)
|
||||
|
||||
|
||||
def _get_default_labels() -> List[str]:
|
||||
"""Returns a list of labels that are always added."""
|
||||
framework_label = f"{_ADK_LABEL}/{version.__version__}"
|
||||
|
||||
if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME):
|
||||
framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}"
|
||||
|
||||
language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0]
|
||||
return [framework_label, language_label]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def client_label_context(client_label: str):
|
||||
"""Runs the operation within the context of the given client label."""
|
||||
current_client_label = _LABEL_CONTEXT.get()
|
||||
|
||||
if current_client_label is not None:
|
||||
raise ValueError(
|
||||
"Client label already exists. You can only add one client label."
|
||||
)
|
||||
|
||||
token = _LABEL_CONTEXT.set(client_label)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore the previous state of the context variable
|
||||
_LABEL_CONTEXT.reset(token)
|
||||
|
||||
|
||||
def get_client_labels() -> List[str]:
|
||||
"""Returns the current list of client labels that can be added to HTTP Headers."""
|
||||
labels = _get_default_labels()
|
||||
current_client_label = _LABEL_CONTEXT.get()
|
||||
|
||||
if current_client_label:
|
||||
labels.append(current_client_label)
|
||||
|
||||
return labels
|
||||
@@ -22,8 +22,6 @@ from google.adk import version as adk_version
|
||||
from google.adk.agents.context_cache_config import ContextCacheConfig
|
||||
from google.adk.models.cache_metadata import CacheMetadata
|
||||
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
|
||||
from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
|
||||
from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG
|
||||
from google.adk.models.google_llm import _build_function_declaration_log
|
||||
from google.adk.models.google_llm import _build_request_log
|
||||
from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE
|
||||
@@ -31,6 +29,8 @@ from google.adk.models.google_llm import _ResourceExhaustedError
|
||||
from google.adk.models.google_llm import Gemini
|
||||
from google.adk.models.llm_request import LlmRequest
|
||||
from google.adk.models.llm_response import LlmResponse
|
||||
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
|
||||
from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
|
||||
from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
from google.genai import types
|
||||
from google.genai.errors import ClientError
|
||||
@@ -142,13 +142,6 @@ def llm_request_with_computer_use():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_os_environ():
|
||||
initial_env = os.environ.copy()
|
||||
with mock.patch.dict(os.environ, initial_env, clear=False) as m:
|
||||
yield m
|
||||
|
||||
|
||||
def test_supported_models():
|
||||
models = Gemini.supported_models()
|
||||
assert len(models) == 4
|
||||
@@ -193,12 +186,15 @@ def test_client_version_header():
|
||||
)
|
||||
|
||||
|
||||
def test_client_version_header_with_agent_engine(mock_os_environ):
|
||||
os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project"
|
||||
def test_client_version_header_with_agent_engine(monkeypatch):
|
||||
monkeypatch.setenv(
|
||||
_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project"
|
||||
)
|
||||
model = Gemini(model="gemini-1.5-flash")
|
||||
client = model.api_client
|
||||
|
||||
# Check that ADK version with telemetry tag and Python version are present in headers
|
||||
# Check that ADK version with telemetry tag and Python version are present in
|
||||
# headers
|
||||
adk_version_with_telemetry = (
|
||||
f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}"
|
||||
)
|
||||
@@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers(
|
||||
"""Test that tracking headers are updated when custom headers are provided."""
|
||||
# Add custom headers to the request config
|
||||
custom_headers = {"custom-header": "custom-value"}
|
||||
for key in gemini_llm._tracking_headers:
|
||||
custom_headers[key] = "custom " + gemini_llm._tracking_headers[key]
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
for key in tracking_headers:
|
||||
custom_headers[key] = "custom " + tracking_headers[key]
|
||||
llm_request.config.http_options = types.HttpOptions(headers=custom_headers)
|
||||
|
||||
with mock.patch.object(gemini_llm, "api_client") as mock_client:
|
||||
@@ -497,8 +494,9 @@ async def test_generate_content_async_with_custom_headers(
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
for key, value in config_arg.http_options.headers.items():
|
||||
if key in gemini_llm._tracking_headers:
|
||||
assert value == gemini_llm._tracking_headers[key] + " custom"
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
if key in tracking_headers:
|
||||
assert value == tracking_headers[key] + " custom"
|
||||
else:
|
||||
assert value == custom_headers[key]
|
||||
|
||||
@@ -547,7 +545,7 @@ async def test_generate_content_async_stream_with_custom_headers(
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(gemini_llm._tracking_headers)
|
||||
expected_headers.update(gemini_llm._tracking_headers())
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
assert len(responses) == 2
|
||||
@@ -601,7 +599,7 @@ async def test_generate_content_async_patches_tracking_headers(
|
||||
assert final_config.http_options is not None
|
||||
assert (
|
||||
final_config.http_options.headers["x-goog-api-client"]
|
||||
== gemini_llm._tracking_headers["x-goog-api-client"]
|
||||
== gemini_llm._tracking_headers()["x-goog-api-client"]
|
||||
)
|
||||
|
||||
assert len(responses) == 2 if stream else 1
|
||||
@@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm):
|
||||
assert http_options.api_version == "v1beta1"
|
||||
|
||||
# Check that tracking headers are included
|
||||
tracking_headers = gemini_llm._tracking_headers
|
||||
tracking_headers = gemini_llm._tracking_headers()
|
||||
for key, value in tracking_headers.items():
|
||||
assert key in http_options.headers
|
||||
assert value in http_options.headers[key]
|
||||
@@ -673,7 +671,7 @@ async def test_connect_with_custom_headers(gemini_llm, llm_request):
|
||||
|
||||
# Verify that tracking headers were merged with custom headers
|
||||
expected_headers = custom_headers.copy()
|
||||
expected_headers.update(gemini_llm._tracking_headers)
|
||||
expected_headers.update(gemini_llm._tracking_headers())
|
||||
assert config_arg.http_options.headers == expected_headers
|
||||
|
||||
# Verify that API version was set
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
from google.adk import version
|
||||
from google.adk.utils import _client_labels_utils
|
||||
import pytest
|
||||
|
||||
|
||||
def test_get_client_labels_default():
|
||||
"""Test get_client_labels returns default labels."""
|
||||
labels = _client_labels_utils.get_client_labels()
|
||||
assert len(labels) == 2
|
||||
assert f"google-adk/{version.__version__}" == labels[0]
|
||||
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
|
||||
|
||||
|
||||
def test_get_client_labels_with_agent_engine_id(monkeypatch):
|
||||
"""Test get_client_labels returns agent engine tag when env var is set."""
|
||||
monkeypatch.setenv(
|
||||
_client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
|
||||
"test-agent-id",
|
||||
)
|
||||
labels = _client_labels_utils.get_client_labels()
|
||||
assert len(labels) == 2
|
||||
assert (
|
||||
f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}"
|
||||
== labels[0]
|
||||
)
|
||||
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
|
||||
|
||||
|
||||
def test_get_client_labels_with_context():
|
||||
"""Test get_client_labels includes label from context."""
|
||||
with _client_labels_utils.client_label_context("my-label/1.0"):
|
||||
labels = _client_labels_utils.get_client_labels()
|
||||
assert len(labels) == 3
|
||||
assert f"google-adk/{version.__version__}" == labels[0]
|
||||
assert f"gl-python/{sys.version.split()[0]}" == labels[1]
|
||||
assert "my-label/1.0" == labels[2]
|
||||
|
||||
|
||||
def test_client_label_context_nested_error():
|
||||
"""Test client_label_context raises error when nested."""
|
||||
with pytest.raises(ValueError, match="Client label already exists"):
|
||||
with _client_labels_utils.client_label_context("my-label/1.0"):
|
||||
with _client_labels_utils.client_label_context("another-label/1.0"):
|
||||
pass
|
||||
|
||||
|
||||
def test_eval_client_label():
|
||||
"""Test EVAL_CLIENT_LABEL has correct format."""
|
||||
assert (
|
||||
f"google-adk-eval/{version.__version__}"
|
||||
== _client_labels_utils.EVAL_CLIENT_LABEL
|
||||
)
|
||||
Reference in New Issue
Block a user