diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index d3c92dfe..92f94c75 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -25,7 +25,6 @@ from google.adk import version as adk_version from google.genai import types from typing_extensions import override -from ..utils._google_client_headers import merge_tracking_headers from ..utils.env_utils import is_env_enabled from .google_llm import Gemini @@ -146,7 +145,7 @@ class ApigeeLlm(Gemini): kwargs_for_http_options['api_version'] = self._api_version http_options = types.HttpOptions( base_url=self._proxy_url, - headers=merge_tracking_headers(self._custom_headers), + headers=self._merge_tracking_headers(self._custom_headers), retry_options=self.retry_options, **kwargs_for_http_options, ) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 78bf535d..ab65bf07 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -192,7 +192,7 @@ class Gemini(BaseLlm): if llm_request.config: if not llm_request.config.http_options: llm_request.config.http_options = types.HttpOptions() - llm_request.config.http_options.headers = merge_tracking_headers( + llm_request.config.http_options.headers = self._merge_tracking_headers( llm_request.config.http_options.headers ) @@ -303,7 +303,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=get_tracking_headers(), + headers=self._tracking_headers(), retry_options=self.retry_options, ) ) @@ -316,6 +316,9 @@ class Gemini(BaseLlm): else GoogleLLMVariant.GEMINI_API ) + def _tracking_headers(self) -> dict[str, str]: + return get_tracking_headers() + @cached_property def _live_api_version(self) -> str: if self._api_backend == GoogleLLMVariant.VERTEX_AI: @@ -331,7 +334,7 @@ class Gemini(BaseLlm): return Client( http_options=types.HttpOptions( - headers=get_tracking_headers(), api_version=self._live_api_version + headers=self._tracking_headers(), api_version=self._live_api_version ) ) @@ -355,7 +358,7 @@ class Gemini(BaseLlm): if not llm_request.live_connect_config.http_options.headers: llm_request.live_connect_config.http_options.headers = {} llm_request.live_connect_config.http_options.headers = ( - merge_tracking_headers( + self._merge_tracking_headers( llm_request.live_connect_config.http_options.headers ) ) @@ -448,6 +451,10 @@ class Gemini(BaseLlm): llm_request.config.system_instruction = None await self._adapt_computer_use_tool(llm_request) + def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: + """Merge tracking headers to the given headers.""" + return merge_tracking_headers(headers) + def _build_function_declaration_log( func_decl: types.FunctionDeclaration,