From 23d330eef1ed23696b56395ef2dee765eec3310f Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Thu, 8 Jan 2026 16:22:18 -0800 Subject: [PATCH] feat: Include model ID with token usage for live events This allows users to track token usage data per model and fixes https://github.com/google/adk-python/issues/4084. Co-authored-by: Kathy Wu PiperOrigin-RevId: 853925212 --- .../adk/models/gemini_llm_connection.py | 8 ++++- src/google/adk/models/google_llm.py | 6 +++- .../models/test_gemini_llm_connection.py | 11 +++++-- tests/unittests/models/test_google_llm.py | 29 ++++++++++++------- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 55d4b62e..327157e2 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -41,11 +41,13 @@ class GeminiLlmConnection(BaseLlmConnection): self, gemini_session: live.AsyncSession, api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI, + model_version: str | None = None, ): self._gemini_session = gemini_session self._input_transcription_text: str = '' self._output_transcription_text: str = '' self._api_backend = api_backend + self._model_version = model_version async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -162,7 +164,11 @@ class GeminiLlmConnection(BaseLlmConnection): async for message in agen: logger.debug('Got LLM Live message: %s', message) if message.usage_metadata: - yield LlmResponse(usage_metadata=message.usage_metadata) + # Tracks token usage data per model. + yield LlmResponse( + usage_metadata=message.usage_metadata, + model_version=self._model_version, + ) if message.server_content: content = message.server_content.model_turn if content and content.parts: diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 9261fada..c38f854c 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -402,7 +402,11 @@ class Gemini(BaseLlm): async with self._live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: - yield GeminiLlmConnection(live_session, api_backend=self._api_backend) + yield GeminiLlmConnection( + live_session, + api_backend=self._api_backend, + model_version=llm_request.model, + ) async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: """Adapt the google computer use predefined functions to the adk computer use toolset.""" diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 19000760..de8f4f9d 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -19,6 +19,8 @@ from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types import pytest +MODEL_VERSION = 'gemini-2.5-pro' + @pytest.fixture def mock_gemini_session(): @@ -30,7 +32,9 @@ def mock_gemini_session(): def gemini_connection(mock_gemini_session): """GeminiLlmConnection instance with mocked session.""" return GeminiLlmConnection( - mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + mock_gemini_session, + api_backend=GoogleLLMVariant.VERTEX_AI, + model_version=MODEL_VERSION, ) @@ -38,7 +42,9 @@ def gemini_connection(mock_gemini_session): def gemini_api_connection(mock_gemini_session): """GeminiLlmConnection instance with mocked session for Gemini API.""" return GeminiLlmConnection( - mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API + mock_gemini_session, + api_backend=GoogleLLMVariant.GEMINI_API, + model_version=MODEL_VERSION, ) @@ -215,6 +221,7 @@ async def test_receive_usage_metadata_and_server_content( usage_response = next((r for r in responses if r.usage_metadata), None) assert usage_response is not None + assert usage_response.model_version == MODEL_VERSION content_response = next((r for r in responses if r.content), None) assert content_response is not None diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index c42260a2..47ff33bc 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -705,20 +705,27 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request): mock_live_client.aio.live.connect.return_value = MockLiveConnect() - async with gemini_llm.connect(llm_request) as connection: - # Verify that the connect method was called with the right config - mock_live_client.aio.live.connect.assert_called_once() - call_args = mock_live_client.aio.live.connect.call_args - config_arg = call_args.kwargs["config"] + with mock.patch( + "google.adk.models.google_llm.GeminiLlmConnection" + ) as MockGeminiLlmConnection: + async with gemini_llm.connect(llm_request) as connection: + # Verify that the connect method was called with the right config + mock_live_client.aio.live.connect.assert_called_once() + call_args = mock_live_client.aio.live.connect.call_args + config_arg = call_args.kwargs["config"] - # Verify that http_options remains None since no custom headers were provided - assert config_arg.http_options is None + # Verify that http_options remains None since no custom headers were provided + assert config_arg.http_options is None - # Verify that system instruction and tools were still set - assert config_arg.system_instruction is not None - assert config_arg.tools == llm_request.config.tools + # Verify that system instruction and tools were still set + assert config_arg.system_instruction is not None + assert config_arg.tools == llm_request.config.tools - assert isinstance(connection, GeminiLlmConnection) + MockGeminiLlmConnection.assert_called_once_with( + mock_live_session, + api_backend=gemini_llm._api_backend, + model_version=llm_request.model, + ) @pytest.mark.parametrize(