fix: Only filter out audio content when sending history

Audio is transcribed and thus does not need to be sent, but other blobs (e.g. images) should still be sent.

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 856422986
This commit is contained in:
Xiang (Sean) Zhou
2026-01-14 16:50:45 -08:00
committed by Copybara-Service
parent 89bed43f5e
commit 712b5a393d
3 changed files with 224 additions and 4 deletions
+12 -4
View File
@@ -20,6 +20,7 @@ from typing import Union
from google.genai import types
from ..utils.content_utils import filter_audio_parts
from ..utils.context_utils import Aclosing
from ..utils.variant_utils import GoogleLLMVariant
from .base_llm_connection import BaseLlmConnection
@@ -63,15 +64,22 @@ class GeminiLlmConnection(BaseLlmConnection):
# TODO: Remove this filter and translate unary contents to streaming
# contents properly.
# We ignore any audio from user during the agent transfer phase
# Filter out audio parts from history because:
# 1. audio has already been transcribed.
# 2. sending audio via connection.send or connection.send_live_content is
# not supported by LIVE API (session will be corrupted).
# This method is called when:
# 1. Agent transfer to a new agent
# 2. Establishing a new live connection with previous ADK session history
contents = [
content
filtered
for content in history
if content.parts and content.parts[0].text
if (filtered := filter_audio_parts(content)) is not None
]
logger.debug('Sending history to live connection: %s', contents)
if contents:
logger.debug('Sending history to live connection: %s', contents)
await self._gemini_session.send(
input=types.LiveClientContent(
turns=contents,
+38
View File
@@ -0,0 +1,38 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from google.genai import types
def is_audio_part(part: types.Part) -> bool:
  """Returns True if the part carries audio data, inline or file-based.

  A part counts as audio when either its inline blob or its file reference
  has a mime type starting with ``'audio/'``.

  Args:
    part: The content part to inspect.

  Returns:
    True if the part is audio, False otherwise.
  """
  # bool() guards against leaking a falsy non-bool (e.g. None when both
  # inline_data and file_data are unset) out of the and/or chain, which
  # would violate the declared -> bool annotation.
  return bool(
      (
          part.inline_data
          and part.inline_data.mime_type
          and part.inline_data.mime_type.startswith('audio/')
      )
      or (
          part.file_data
          and part.file_data.mime_type
          and part.file_data.mime_type.startswith('audio/')
      )
  )
def filter_audio_parts(content: types.Content) -> types.Content | None:
  """Returns the content with all audio parts removed, or None.

  Drops every audio part (as judged by ``is_audio_part``) and rebuilds the
  content around the surviving parts.

  Args:
    content: The content whose parts should be filtered.

  Returns:
    A new content holding only the non-audio parts, or None when the
    content has no parts at all or every part was audio — signalling the
    caller to skip this content entirely.
  """
  parts = content.parts
  if not parts:
    return None
  kept = [p for p in parts if not is_audio_part(p)]
  if not kept:
    return None
  return types.Content(role=content.role, parts=kept)
@@ -600,3 +600,177 @@ async def test_receive_handles_output_transcription_fragments(
assert responses[2].output_transcription.text == 'How can I help?'
assert responses[2].output_transcription.finished is True
assert responses[2].partial is False
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'audio_part',
    [
        types.Part(
            inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
        ),
        types.Part(
            file_data=types.FileData(
                file_uri='artifact://app/user/session/_adk_live/audio.pcm#1',
                mime_type='audio/pcm',
            )
        ),
    ],
)
async def test_send_history_filters_audio(mock_gemini_session, audio_part):
  """Audio-only user content is dropped while the model turn survives."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )
  model_part = types.Part.from_text(text='I heard you')
  conversation = [
      types.Content(role='user', parts=[audio_part]),
      types.Content(role='model', parts=[model_part]),
  ]

  await connection.send_history(conversation)

  mock_gemini_session.send.assert_called_once()
  turns = mock_gemini_session.send.call_args[1]['input'].turns
  # The user turn held only audio, so just the model turn remains.
  assert len(turns) == 1
  assert turns[0].role == 'model'
  assert turns[0].parts == [model_part]
@pytest.mark.asyncio
async def test_send_history_keeps_image_data(mock_gemini_session):
  """Non-audio blobs such as images must survive the history filter."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )
  png_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png')
  conversation = [
      types.Content(role='user', parts=[types.Part(inline_data=png_blob)]),
      types.Content(
          role='model', parts=[types.Part.from_text(text='Nice image!')]
      ),
  ]

  await connection.send_history(conversation)

  mock_gemini_session.send.assert_called_once()
  turns = mock_gemini_session.send.call_args[1]['input'].turns
  # Both turns go through: images are not audio, so nothing is filtered.
  assert len(turns) == 2
  assert turns[0].parts[0].inline_data == png_blob
@pytest.mark.asyncio
async def test_send_history_mixed_content_filters_only_audio(
    mock_gemini_session,
):
  """A turn mixing audio and text keeps the text and loses the audio."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )
  audio = types.Part(
      inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/wav')
  )
  text = types.Part.from_text(text='transcribed text')

  await connection.send_history(
      [types.Content(role='user', parts=[audio, text])]
  )

  mock_gemini_session.send.assert_called_once()
  turns = mock_gemini_session.send.call_args[1]['input'].turns
  # The turn itself is sent, but only its text part remains.
  assert len(turns) == 1
  assert [p.text for p in turns[0].parts] == ['transcribed text']
@pytest.mark.asyncio
async def test_send_history_all_audio_content_not_sent(mock_gemini_session):
  """A turn whose every part is audio is dropped, so nothing is sent."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )
  inline_audio = types.Part(
      inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm')
  )
  file_audio = types.Part(
      file_data=types.FileData(
          file_uri='artifact://audio.pcm#1',
          mime_type='audio/wav',
      )
  )

  await connection.send_history(
      [types.Content(role='user', parts=[inline_audio, file_audio])]
  )

  # Filtering removed every part, so the send call is skipped entirely.
  mock_gemini_session.send.assert_not_called()
@pytest.mark.asyncio
async def test_send_history_empty_history_not_sent(mock_gemini_session):
  """An empty history never reaches the live connection."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )

  await connection.send_history([])

  mock_gemini_session.send.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'audio_mime_type',
    ['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'],
)
async def test_send_history_filters_various_audio_mime_types(
    mock_gemini_session,
    audio_mime_type,
):
  """Every audio/* mime type is filtered, regardless of codec."""
  connection = GeminiLlmConnection(
      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
  )
  audio_only = types.Content(
      role='user',
      parts=[
          types.Part(
              inline_data=types.Blob(data=b'', mime_type=audio_mime_type)
          )
      ],
  )

  await connection.send_history([audio_only])

  # The sole part was audio, so the whole turn (and the send) is dropped.
  mock_gemini_session.send.assert_not_called()