You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Include model ID with token usage for live events
This allows users to track token usage data per model, and fixes https://github.com/google/adk-python/issues/4084.
Co-authored-by: Kathy Wu <wukathy@google.com>
PiperOrigin-RevId: 853925212
This commit is contained in:
committed by
Copybara-Service
parent
b8917bc80e
commit
23d330eef1
@@ -41,11 +41,13 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
self,
|
||||
gemini_session: live.AsyncSession,
|
||||
api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI,
|
||||
model_version: str | None = None,
|
||||
):
|
||||
self._gemini_session = gemini_session
|
||||
self._input_transcription_text: str = ''
|
||||
self._output_transcription_text: str = ''
|
||||
self._api_backend = api_backend
|
||||
self._model_version = model_version
|
||||
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the gemini model.
|
||||
@@ -162,7 +164,11 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
async for message in agen:
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.usage_metadata:
|
||||
yield LlmResponse(usage_metadata=message.usage_metadata)
|
||||
# Tracks token usage data per model.
|
||||
yield LlmResponse(
|
||||
usage_metadata=message.usage_metadata,
|
||||
model_version=self._model_version,
|
||||
)
|
||||
if message.server_content:
|
||||
content = message.server_content.model_turn
|
||||
if content and content.parts:
|
||||
|
||||
@@ -402,7 +402,11 @@ class Gemini(BaseLlm):
|
||||
async with self._live_api_client.aio.live.connect(
|
||||
model=llm_request.model, config=llm_request.live_connect_config
|
||||
) as live_session:
|
||||
yield GeminiLlmConnection(live_session, api_backend=self._api_backend)
|
||||
yield GeminiLlmConnection(
|
||||
live_session,
|
||||
api_backend=self._api_backend,
|
||||
model_version=llm_request.model,
|
||||
)
|
||||
|
||||
async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
|
||||
"""Adapt the google computer use predefined functions to the adk computer use toolset."""
|
||||
|
||||
@@ -19,6 +19,8 @@ from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
MODEL_VERSION = 'gemini-2.5-pro'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_gemini_session():
|
||||
@@ -30,7 +32,9 @@ def mock_gemini_session():
|
||||
def gemini_connection(mock_gemini_session):
|
||||
"""GeminiLlmConnection instance with mocked session."""
|
||||
return GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
mock_gemini_session,
|
||||
api_backend=GoogleLLMVariant.VERTEX_AI,
|
||||
model_version=MODEL_VERSION,
|
||||
)
|
||||
|
||||
|
||||
@@ -38,7 +42,9 @@ def gemini_connection(mock_gemini_session):
|
||||
def gemini_api_connection(mock_gemini_session):
|
||||
"""GeminiLlmConnection instance with mocked session for Gemini API."""
|
||||
return GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API
|
||||
mock_gemini_session,
|
||||
api_backend=GoogleLLMVariant.GEMINI_API,
|
||||
model_version=MODEL_VERSION,
|
||||
)
|
||||
|
||||
|
||||
@@ -215,6 +221,7 @@ async def test_receive_usage_metadata_and_server_content(
|
||||
|
||||
usage_response = next((r for r in responses if r.usage_metadata), None)
|
||||
assert usage_response is not None
|
||||
assert usage_response.model_version == MODEL_VERSION
|
||||
content_response = next((r for r in responses if r.content), None)
|
||||
assert content_response is not None
|
||||
|
||||
|
||||
@@ -705,20 +705,27 @@ async def test_connect_without_custom_headers(gemini_llm, llm_request):
|
||||
|
||||
mock_live_client.aio.live.connect.return_value = MockLiveConnect()
|
||||
|
||||
async with gemini_llm.connect(llm_request) as connection:
|
||||
# Verify that the connect method was called with the right config
|
||||
mock_live_client.aio.live.connect.assert_called_once()
|
||||
call_args = mock_live_client.aio.live.connect.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
with mock.patch(
|
||||
"google.adk.models.google_llm.GeminiLlmConnection"
|
||||
) as MockGeminiLlmConnection:
|
||||
async with gemini_llm.connect(llm_request) as connection:
|
||||
# Verify that the connect method was called with the right config
|
||||
mock_live_client.aio.live.connect.assert_called_once()
|
||||
call_args = mock_live_client.aio.live.connect.call_args
|
||||
config_arg = call_args.kwargs["config"]
|
||||
|
||||
# Verify that http_options remains None since no custom headers were provided
|
||||
assert config_arg.http_options is None
|
||||
# Verify that http_options remains None since no custom headers were provided
|
||||
assert config_arg.http_options is None
|
||||
|
||||
# Verify that system instruction and tools were still set
|
||||
assert config_arg.system_instruction is not None
|
||||
assert config_arg.tools == llm_request.config.tools
|
||||
# Verify that system instruction and tools were still set
|
||||
assert config_arg.system_instruction is not None
|
||||
assert config_arg.tools == llm_request.config.tools
|
||||
|
||||
assert isinstance(connection, GeminiLlmConnection)
|
||||
MockGeminiLlmConnection.assert_called_once_with(
|
||||
mock_live_session,
|
||||
api_backend=gemini_llm._api_backend,
|
||||
model_version=llm_request.model,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user