From 712b5a393d44e7b5ce35fc459da98361bae4bb16 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 14 Jan 2026 16:50:45 -0800 Subject: [PATCH] fix: Only filter out audio content when sending history audio is transcribed thus no need to be sent, but other blob(e.g. image) should still be sent. Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856422986 --- .../adk/models/gemini_llm_connection.py | 16 +- src/google/adk/utils/content_utils.py | 38 ++++ .../models/test_gemini_llm_connection.py | 174 ++++++++++++++++++ 3 files changed, 224 insertions(+), 4 deletions(-) create mode 100644 src/google/adk/utils/content_utils.py diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 327157e2..158a5cab 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -20,6 +20,7 @@ from typing import Union from google.genai import types +from ..utils.content_utils import filter_audio_parts from ..utils.context_utils import Aclosing from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection @@ -63,15 +64,22 @@ class GeminiLlmConnection(BaseLlmConnection): # TODO: Remove this filter and translate unary contents to streaming # contents properly. - # We ignore any audio from user during the agent transfer phase + # Filter out audio parts from history because: + # 1. audio has already been transcribed. + # 2. sending audio via connection.send or connection.send_live_content is + # not supported by LIVE API (session will be corrupted). + # This method is called when: + # 1. Agent transfer to a new agent + # 2. Establishing a new live connection with previous ADK session history + contents = [ - content + filtered for content in history - if content.parts and content.parts[0].text + if (filtered := filter_audio_parts(content)) is not None ] - logger.debug('Sending history to live connection: %s', contents) if contents: + logger.debug('Sending history to live connection: %s', contents) await self._gemini_session.send( input=types.LiveClientContent( turns=contents, diff --git a/src/google/adk/utils/content_utils.py b/src/google/adk/utils/content_utils.py new file mode 100644 index 00000000..379c31ec --- /dev/null +++ b/src/google/adk/utils/content_utils.py @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from google.genai import types + + +def is_audio_part(part: types.Part) -> bool: + return ( + part.inline_data + and part.inline_data.mime_type + and part.inline_data.mime_type.startswith('audio/') + ) or ( + part.file_data + and part.file_data.mime_type + and part.file_data.mime_type.startswith('audio/') + ) + + +def filter_audio_parts(content: types.Content) -> types.Content | None: + if not content.parts: + return None + filtered_parts = [part for part in content.parts if not is_audio_part(part)] + if not filtered_parts: + return None + return types.Content(role=content.role, parts=filtered_parts) diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index de8f4f9d..ac65b2ac 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -600,3 +600,177 @@ async def test_receive_handles_output_transcription_fragments( assert responses[2].output_transcription.text == 'How can I help?' assert responses[2].output_transcription.finished is True assert responses[2].partial is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_part', + [ + types.Part( + inline_data=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm') + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://app/user/session/_adk_live/audio.pcm#1', + mime_type='audio/pcm', + ) + ), + ], +) +async def test_send_history_filters_audio(mock_gemini_session, audio_part): + """Test that audio parts (inline or file_data) are filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[audio_part], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='I heard you')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Only the model response should be sent (user audio filtered out) + assert len(sent_contents) == 1 + assert sent_contents[0].role == 'model' + assert sent_contents[0].parts == [types.Part.from_text(text='I heard you')] + + +@pytest.mark.asyncio +async def test_send_history_keeps_image_data(mock_gemini_session): + """Test that image data is NOT filtered out.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + image_blob = types.Blob(data=b'\x89PNG\r\n', mime_type='image/png') + history = [ + types.Content( + role='user', + parts=[types.Part(inline_data=image_blob)], + ), + types.Content( + role='model', parts=[types.Part.from_text(text='Nice image!')] + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Both contents should be sent (image is not filtered) + assert len(sent_contents) == 2 + assert sent_contents[0].parts[0].inline_data == image_blob + + +@pytest.mark.asyncio +async def test_send_history_mixed_content_filters_only_audio( + mock_gemini_session, +): + """Test that mixed content keeps non-audio parts.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/wav' + ) + ), + types.Part.from_text(text='transcribed text'), + ], + ), + ] + + await connection.send_history(history) + + mock_gemini_session.send.assert_called_once() + call_args = mock_gemini_session.send.call_args[1] + sent_contents = call_args['input'].turns + # Content should be sent but only with the text part + assert len(sent_contents) == 1 + assert len(sent_contents[0].parts) == 1 + assert sent_contents[0].parts[0].text == 'transcribed text' + + +@pytest.mark.asyncio +async def test_send_history_all_audio_content_not_sent(mock_gemini_session): + """Test that content with only audio parts is completely removed.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob( + data=b'\x00\xFF', mime_type='audio/pcm' + ) + ), + types.Part( + file_data=types.FileData( + file_uri='artifact://audio.pcm#1', + mime_type='audio/wav', + ) + ), + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since all parts are audio + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_history_empty_history_not_sent(mock_gemini_session): + """Test that empty history does not call send.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + + await connection.send_history([]) + + mock_gemini_session.send.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'audio_mime_type', + ['audio/pcm', 'audio/wav', 'audio/mp3', 'audio/ogg'], +) +async def test_send_history_filters_various_audio_mime_types( + mock_gemini_session, + audio_mime_type, +): + """Test that various audio mime types are all filtered.""" + connection = GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + history = [ + types.Content( + role='user', + parts=[ + types.Part( + inline_data=types.Blob(data=b'', mime_type=audio_mime_type) + ) + ], + ), + ] + + await connection.send_history(history) + + # No content should be sent since the only part is audio + mock_gemini_session.send.assert_not_called()