You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Only filter out audio content when sending history
audio is transcribed thus no need to be sent, but other blob(e.g. image) should still be sent. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 856422986
This commit is contained in:
committed by
Copybara-Service
parent
89bed43f5e
commit
712b5a393d
@@ -20,6 +20,7 @@ from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
|
||||
from ..utils.content_utils import filter_audio_parts
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
@@ -63,15 +64,22 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
# TODO: Remove this filter and translate unary contents to streaming
|
||||
# contents properly.
|
||||
|
||||
# We ignore any audio from user during the agent transfer phase
|
||||
# Filter out audio parts from history because:
|
||||
# 1. audio has already been transcribed.
|
||||
# 2. sending audio via connection.send or connection.send_live_content is
|
||||
# not supported by LIVE API (session will be corrupted).
|
||||
# This method is called when:
|
||||
# 1. Agent transfer to a new agent
|
||||
# 2. Establishing a new live connection with previous ADK session history
|
||||
|
||||
contents = [
|
||||
content
|
||||
filtered
|
||||
for content in history
|
||||
if content.parts and content.parts[0].text
|
||||
if (filtered := filter_audio_parts(content)) is not None
|
||||
]
|
||||
logger.debug('Sending history to live connection: %s', contents)
|
||||
|
||||
if contents:
|
||||
logger.debug('Sending history to live connection: %s', contents)
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientContent(
|
||||
turns=contents,
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.genai import types
|
||||
|
||||
|
||||
def is_audio_part(part: types.Part) -> bool:
|
||||
return (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
and part.inline_data.mime_type.startswith('audio/')
|
||||
) or (
|
||||
part.file_data
|
||||
and part.file_data.mime_type
|
||||
and part.file_data.mime_type.startswith('audio/')
|
||||
)
|
||||
|
||||
|
||||
def filter_audio_parts(content: types.Content) -> types.Content | None:
|
||||
if not content.parts:
|
||||
return None
|
||||
filtered_parts = [part for part in content.parts if not is_audio_part(part)]
|
||||
if not filtered_parts:
|
||||
return None
|
||||
return types.Content(role=content.role, parts=filtered_parts)
|
||||
@@ -600,3 +600,177 @@ async def test_receive_handles_output_transcription_fragments(
|
||||
assert responses[2].output_transcription.text == 'How can I help?'
|
||||
assert responses[2].output_transcription.finished is True
|
||||
assert responses[2].partial is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'audio_part',
|
||||
[
|
||||
types.Part(
|
||||
inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
|
||||
),
|
||||
types.Part(
|
||||
file_data=types.FileData(
|
||||
file_uri='artifact://app/user/session/_adk_live/audio.pcm#1',
|
||||
mime_type='audio/pcm',
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_send_history_filters_audio(mock_gemini_session, audio_part):
|
||||
"""Test that audio parts (inline or file_data) are filtered out."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
history = [
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[audio_part],
|
||||
),
|
||||
types.Content(
|
||||
role='model', parts=[types.Part.from_text(text='I heard you')]
|
||||
),
|
||||
]
|
||||
|
||||
await connection.send_history(history)
|
||||
|
||||
mock_gemini_session.send.assert_called_once()
|
||||
call_args = mock_gemini_session.send.call_args[1]
|
||||
sent_contents = call_args['input'].turns
|
||||
# Only the model response should be sent (user audio filtered out)
|
||||
assert len(sent_contents) == 1
|
||||
assert sent_contents[0].role == 'model'
|
||||
assert sent_contents[0].parts == [types.Part.from_text(text='I heard you')]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_history_keeps_image_data(mock_gemini_session):
|
||||
"""Test that image data is NOT filtered out."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
image_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png')
|
||||
history = [
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(inline_data=image_blob)],
|
||||
),
|
||||
types.Content(
|
||||
role='model', parts=[types.Part.from_text(text='Nice image!')]
|
||||
),
|
||||
]
|
||||
|
||||
await connection.send_history(history)
|
||||
|
||||
mock_gemini_session.send.assert_called_once()
|
||||
call_args = mock_gemini_session.send.call_args[1]
|
||||
sent_contents = call_args['input'].turns
|
||||
# Both contents should be sent (image is not filtered)
|
||||
assert len(sent_contents) == 2
|
||||
assert sent_contents[0].parts[0].inline_data == image_blob
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_history_mixed_content_filters_only_audio(
|
||||
mock_gemini_session,
|
||||
):
|
||||
"""Test that mixed content keeps non-audio parts."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
history = [
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
data=b'\x00\xFF', mime_type='audio/wav'
|
||||
)
|
||||
),
|
||||
types.Part.from_text(text='transcribed text'),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
await connection.send_history(history)
|
||||
|
||||
mock_gemini_session.send.assert_called_once()
|
||||
call_args = mock_gemini_session.send.call_args[1]
|
||||
sent_contents = call_args['input'].turns
|
||||
# Content should be sent but only with the text part
|
||||
assert len(sent_contents) == 1
|
||||
assert len(sent_contents[0].parts) == 1
|
||||
assert sent_contents[0].parts[0].text == 'transcribed text'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_history_all_audio_content_not_sent(mock_gemini_session):
|
||||
"""Test that content with only audio parts is completely removed."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
history = [
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
inline_data=types.Blob(
|
||||
data=b'\x00\xFF', mime_type='audio/pcm'
|
||||
)
|
||||
),
|
||||
types.Part(
|
||||
file_data=types.FileData(
|
||||
file_uri='artifact://audio.pcm#1',
|
||||
mime_type='audio/wav',
|
||||
)
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
await connection.send_history(history)
|
||||
|
||||
# No content should be sent since all parts are audio
|
||||
mock_gemini_session.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_history_empty_history_not_sent(mock_gemini_session):
|
||||
"""Test that empty history does not call send."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
|
||||
await connection.send_history([])
|
||||
|
||||
mock_gemini_session.send.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'audio_mime_type',
|
||||
['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'],
|
||||
)
|
||||
async def test_send_history_filters_various_audio_mime_types(
|
||||
mock_gemini_session,
|
||||
audio_mime_type,
|
||||
):
|
||||
"""Test that various audio mime types are all filtered."""
|
||||
connection = GeminiLlmConnection(
|
||||
mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
history = [
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
inline_data=types.Blob(data=b'', mime_type=audio_mime_type)
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
await connection.send_history(history)
|
||||
|
||||
# No content should be sent since the only part is audio
|
||||
mock_gemini_session.send.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user