fix: Refactor LiteLLM streaming response parsing for compatibility with LiteLLM 1.81+

Updates _model_response_to_chunk to better handle LiteLLM's streaming delta/message structure, including prioritizing delta when it contains meaningful content and preserving reasoning_content

Close #4225

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 873097502
This commit is contained in:
George Weale
2026-02-20 15:23:26 -08:00
committed by Copybara-Service
parent 6ea3696bcc
commit e8019b1b1b
3 changed files with 263 additions and 112 deletions
+2 -2
View File
@@ -124,7 +124,7 @@ test = [
"kubernetes>=29.0.0", # For GkeCodeExecutor
"langchain-community>=0.3.17",
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
"litellm>=1.75.5, <1.80.17", # For LiteLLM tests
"litellm>=1.75.5, <2.0.0", # For LiteLLM tests
"llama-index-readers-file>=0.4.0", # For retrieval tests
"openai>=1.100.2", # For LiteLLM
"opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0",
@@ -156,7 +156,7 @@ extensions = [
"docker>=7.0.0", # For ContainerCodeExecutor
"kubernetes>=29.0.0", # For GkeCodeExecutor
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
"litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
"litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
"llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex.
"lxml>=5.3.0", # For load_web_page tool.
+158 -78
View File
@@ -70,7 +70,9 @@ if TYPE_CHECKING:
from litellm import Function
from litellm import Message
from litellm import ModelResponse
from litellm import ModelResponseStream
from litellm import OpenAIMessageContent
from litellm.types.utils import Delta
else:
litellm = None
acompletion = None
@@ -85,7 +87,9 @@ else:
Function = None
Message = None
ModelResponse = None
Delta = None
OpenAIMessageContent = None
ModelResponseStream = None
logger = logging.getLogger("google_adk." + __name__)
@@ -151,6 +155,7 @@ _LITELLM_GLOBAL_SYMBOLS = (
"Function",
"Message",
"ModelResponse",
"ModelResponseStream",
"OpenAIMessageContent",
"acompletion",
"completion",
@@ -382,15 +387,11 @@ def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]:
]
def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any:
"""Fetches the reasoning payload from a LiteLLM message or dict."""
def _extract_reasoning_value(message: Message | Delta | None) -> Any:
"""Fetches the reasoning payload from a LiteLLM message."""
if message is None:
return None
if hasattr(message, "reasoning_content"):
return getattr(message, "reasoning_content")
if isinstance(message, dict):
return message.get("reasoning_content")
return None
return message.get("reasoning_content")
class ChatCompletionFileUrlObject(TypedDict, total=False):
@@ -1264,7 +1265,7 @@ def _function_declaration_to_tool_param(
def _model_response_to_chunk(
response: ModelResponse,
response: ModelResponse | ModelResponseStream,
) -> Generator[
Tuple[
Optional[
@@ -1282,6 +1283,9 @@ def _model_response_to_chunk(
]:
"""Converts a litellm message to text, function or usage metadata chunk.
LiteLLM streaming chunks carry `delta`, while non-streaming chunks carry
`message`.
Args:
response: The response from the model.
@@ -1290,18 +1294,45 @@ def _model_response_to_chunk(
"""
_ensure_litellm_imported()
message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = response["choices"][0].get("finish_reason", None)
# check streaming delta
if message is None and response["choices"][0].get("delta", None):
message = response["choices"][0]["delta"]
def _has_meaningful_signal(message: Message | Delta | None) -> bool:
if message is None:
return False
return bool(
message.get("content")
or message.get("tool_calls")
or message.get("function_call")
or message.get("reasoning_content")
)
if isinstance(response, ModelResponseStream):
message_field = "delta"
elif isinstance(response, ModelResponse):
message_field = "message"
else:
raise TypeError(
"Unexpected response type from LiteLLM: %r" % (type(response),)
)
choices = response.get("choices")
if not choices:
yield None, None
else:
choice = choices[0]
finish_reason = choice.get("finish_reason")
if message_field == "delta":
message = choice.get("delta")
else:
message = choice.get("message")
if message is not None and not _has_meaningful_signal(message):
message = None
message_content: Optional[OpenAIMessageContent] = None
tool_calls: list[ChatCompletionMessageToolCall] = []
reasoning_parts: List[types.Part] = []
if message is not None:
# Both Delta and Message support dict-like .get() access
(
message_content,
tool_calls,
@@ -1318,39 +1349,46 @@ def _model_response_to_chunk(
if tool_calls:
for idx, tool_call in enumerate(tool_calls):
# aggregate tool_call
if tool_call.type == "function":
func_name = tool_call.function.name
func_args = tool_call.function.arguments
func_index = getattr(tool_call, "index", idx)
# LiteLLM tool call objects support dict-like .get() access
if tool_call.get("type") == "function":
function_obj = tool_call.get("function")
if not function_obj:
continue
func_name = function_obj.get("name")
func_args = function_obj.get("arguments")
func_index = tool_call.get("index", idx)
tool_call_id = tool_call.get("id")
# Ignore empty chunks that don't carry any information.
if not func_name and not func_args:
continue
yield FunctionChunk(
id=tool_call.id,
id=tool_call_id,
name=func_name,
args=func_args,
index=func_index,
), finish_reason
if finish_reason and not (message_content or tool_calls):
if finish_reason and not (message_content or tool_calls or reasoning_parts):
yield None, finish_reason
if not message:
yield None, None
# Ideally usage would be expected with the last ModelResponseStream with a
# finish_reason set. But this is not the case we are observing from litellm.
# So we are sending it as a separate chunk to be set on the llm_response.
if response.get("usage", None):
yield UsageMetadataChunk(
prompt_tokens=response["usage"].get("prompt_tokens", 0),
completion_tokens=response["usage"].get("completion_tokens", 0),
total_tokens=response["usage"].get("total_tokens", 0),
cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]),
), None
usage = response.get("usage")
if usage:
try:
yield UsageMetadataChunk(
prompt_tokens=usage.get("prompt_tokens", 0) or 0,
completion_tokens=usage.get("completion_tokens", 0) or 0,
total_tokens=usage.get("total_tokens", 0) or 0,
cached_prompt_tokens=_extract_cached_prompt_tokens(usage),
), None
except AttributeError as e:
raise TypeError(
"Unexpected LiteLLM usage type: %r" % (type(usage),)
) from e
def _model_response_to_generate_content_response(
@@ -1902,6 +1940,57 @@ class LiteLlm(BaseLlm):
aggregated_llm_response_with_tool_call = None
usage_metadata = None
fallback_index = 0
def _finalize_tool_call_response(
*, model_version: str, finish_reason: str
) -> LlmResponse:
tool_calls = []
for index, func_data in function_calls.items():
if func_data["id"]:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=func_data["id"],
function=Function(
name=func_data["name"],
arguments=func_data["args"],
index=index,
),
)
)
llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=text,
tool_calls=tool_calls,
),
model_version=model_version,
thought_parts=list(reasoning_parts) if reasoning_parts else None,
)
llm_response.finish_reason = _map_finish_reason(finish_reason)
return llm_response
def _finalize_text_response(
*, model_version: str, finish_reason: str
) -> LlmResponse:
message_content = text if text else None
llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=message_content,
),
model_version=model_version,
thought_parts=list(reasoning_parts) if reasoning_parts else None,
)
llm_response.finish_reason = _map_finish_reason(finish_reason)
return llm_response
def _reset_stream_buffers() -> None:
nonlocal text, reasoning_parts
text = ""
reasoning_parts = []
function_calls.clear()
async for part in await self.llm_client.acompletion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
@@ -1951,58 +2040,49 @@ class LiteLlm(BaseLlm):
cached_content_token_count=chunk.cached_prompt_tokens,
)
if (
finish_reason == "tool_calls" or finish_reason == "stop"
) and function_calls:
tool_calls = []
for index, func_data in function_calls.items():
if func_data["id"]:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=func_data["id"],
function=Function(
name=func_data["name"],
arguments=func_data["args"],
index=index,
),
)
)
# LiteLLM 1.81+ can set finish_reason="stop" on partial chunks. Only
# finalize tool calls on an explicit tool_calls finish_reason, or on a
# stop-only chunk (no content/tool deltas).
if function_calls and (
finish_reason == "tool_calls"
or (finish_reason == "stop" and chunk is None)
):
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=text,
tool_calls=tool_calls,
),
_finalize_tool_call_response(
model_version=part.model,
thought_parts=list(reasoning_parts)
if reasoning_parts
else None,
finish_reason=finish_reason,
)
)
aggregated_llm_response_with_tool_call.finish_reason = (
_map_finish_reason(finish_reason)
)
text = ""
reasoning_parts = []
function_calls.clear()
elif finish_reason == "stop" and (text or reasoning_parts):
message_content = text if text else None
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant", content=message_content
),
_reset_stream_buffers()
elif (
finish_reason == "stop"
and (text or reasoning_parts)
and chunk is None
and not function_calls
):
# Only aggregate text response when we have a true stop signal
# chunk is None means no content in this chunk, just finish signal.
# LiteLLM 1.81+ sets finish_reason="stop" on partial chunks with
# content.
aggregated_llm_response = _finalize_text_response(
model_version=part.model,
thought_parts=list(reasoning_parts)
if reasoning_parts
else None,
finish_reason=finish_reason,
)
aggregated_llm_response.finish_reason = _map_finish_reason(
finish_reason
)
text = ""
reasoning_parts = []
_reset_stream_buffers()
if function_calls and not aggregated_llm_response_with_tool_call:
aggregated_llm_response_with_tool_call = _finalize_tool_call_response(
model_version=part.model,
finish_reason="tool_calls",
)
_reset_stream_buffers()
if (text or reasoning_parts) and not aggregated_llm_response:
aggregated_llm_response = _finalize_text_response(
model_version=part.model,
finish_reason="stop",
)
_reset_stream_buffers()
# waiting until streaming ends to yield the llm_response as litellm tends
# to send chunk that contains usage_metadata after the chunk with
+103 -32
View File
@@ -45,6 +45,7 @@ from google.adk.models.lite_llm import _to_litellm_role
from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import ReasoningChunk
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk
from google.adk.models.llm_request import LlmRequest
@@ -57,6 +58,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Choices
from litellm.types.utils import Delta
from litellm.types.utils import ModelResponse
from litellm.types.utils import ModelResponseStream
from litellm.types.utils import StreamingChoices
from pydantic import BaseModel
from pydantic import Field
@@ -129,7 +131,7 @@ FILE_BYTES_TEST_CASES = [
]
STREAMING_MODEL_RESPONSE = [
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -141,7 +143,7 @@ STREAMING_MODEL_RESPONSE = [
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -153,7 +155,7 @@ STREAMING_MODEL_RESPONSE = [
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -165,7 +167,7 @@ STREAMING_MODEL_RESPONSE = [
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -187,7 +189,7 @@ STREAMING_MODEL_RESPONSE = [
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -209,7 +211,7 @@ STREAMING_MODEL_RESPONSE = [
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -532,7 +534,7 @@ def test_schema_to_dict_filters_none_enum_values():
MULTIPLE_FUNCTION_CALLS_STREAM = [
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -553,7 +555,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -574,7 +576,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -595,7 +597,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -616,7 +618,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason="tool_calls",
@@ -627,7 +629,7 @@ MULTIPLE_FUNCTION_CALLS_STREAM = [
STREAM_WITH_EMPTY_CHUNK = [
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -648,7 +650,7 @@ STREAM_WITH_EMPTY_CHUNK = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -670,7 +672,7 @@ STREAM_WITH_EMPTY_CHUNK = [
]
),
# This is the problematic empty chunk that should be ignored.
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -691,7 +693,7 @@ STREAM_WITH_EMPTY_CHUNK = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[StreamingChoices(finish_reason="tool_calls", delta=Delta())]
),
]
@@ -727,7 +729,7 @@ def mock_response():
# indices all 0
# finish_reason stop instead of tool_calls
NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -748,7 +750,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -769,7 +771,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -790,7 +792,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -811,7 +813,7 @@ NON_COMPLIANT_MULTIPLE_FUNCTION_CALLS_STREAM = [
)
]
),
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason="stop",
@@ -2707,7 +2709,7 @@ def test_to_litellm_role():
"stop",
),
(
ModelResponse(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
@@ -2729,10 +2731,10 @@ def test_to_litellm_role():
]
),
[FunctionChunk(id="1", name="test_function", args='{"key": "va')],
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
None,
# LiteLLM 1.81+ defaults finish_reason to "stop" for partial chunks,
# older versions return None. Both are valid for streaming chunks.
(None, "stop"),
),
(
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
@@ -2813,6 +2815,38 @@ def test_to_litellm_role():
),
"tool_calls",
),
(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(role="assistant", content="Hello"),
)
]
),
[TextChunk(text="Hello")],
None,
(None, "stop"),
),
(
ModelResponseStream(
choices=[
StreamingChoices(
finish_reason="stop",
delta=Delta(
role="assistant", reasoning_content="thinking..."
),
)
]
),
[
ReasoningChunk(
parts=[types.Part(text="thinking...", thought=True)]
)
],
None,
"stop",
),
],
)
def test_model_response_to_chunk(
@@ -2836,7 +2870,10 @@ def test_model_response_to_chunk(
else:
assert isinstance(chunk, type(expected_chunk))
assert chunk == expected_chunk
assert finished == expected_finished
if isinstance(expected_finished, tuple):
assert finished in expected_finished
else:
assert finished == expected_finished
if expected_usage_chunk is None:
assert usage_chunk is None
@@ -2845,6 +2882,38 @@ def test_model_response_to_chunk(
assert usage_chunk == expected_usage_chunk
def test_model_response_to_chunk_does_not_mutate_delta_object():
"""Verify that _model_response_to_chunk doesn't mutate the Delta object.
In real streaming responses, LiteLLM's StreamingChoices only has 'delta'
(message is explicitly popped in StreamingChoices constructor). The delta
object itself carries reasoning_content when present.
"""
delta = Delta(
role="assistant", content="Hello", reasoning_content="thinking..."
)
response = ModelResponseStream(
choices=[StreamingChoices(delta=delta, finish_reason=None)]
)
chunks = [chunk for chunk, _ in _model_response_to_chunk(response) if chunk]
assert (
ReasoningChunk(parts=[types.Part(text="thinking...", thought=True)])
in chunks
)
assert TextChunk(text="Hello") in chunks
# Verify we don't accidentally mutate the original delta object.
assert delta.content == "Hello"
assert delta.reasoning_content == "thinking..."
def test_model_response_to_chunk_rejects_dict_response():
with pytest.raises(TypeError):
list(_model_response_to_chunk({"choices": []}))
@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
lite_llm_instance = LiteLlm(
@@ -3056,7 +3125,7 @@ async def test_generate_content_async_stream_sets_finish_reason(
mock_completion, lite_llm_instance
):
mock_completion.return_value = iter([
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -3065,7 +3134,7 @@ async def test_generate_content_async_stream_sets_finish_reason(
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[
StreamingChoices(
@@ -3074,7 +3143,7 @@ async def test_generate_content_async_stream_sets_finish_reason(
)
],
),
ModelResponse(
ModelResponseStream(
model="test_model",
choices=[StreamingChoices(finish_reason="stop", delta=Delta())],
),
@@ -3107,7 +3176,7 @@ async def test_generate_content_async_stream_with_usage_metadata(
streaming_model_response_with_usage_metadata = [
*STREAMING_MODEL_RESPONSE,
ModelResponse(
ModelResponseStream(
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
@@ -3176,7 +3245,7 @@ async def test_generate_content_async_stream_with_usage_metadata(
"""Tests that cached prompt tokens are propagated in streaming mode."""
streaming_model_response_with_usage_metadata = [
*STREAMING_MODEL_RESPONSE,
ModelResponse(
ModelResponseStream(
usage={
"prompt_tokens": 10,
"completion_tokens": 5,
@@ -3657,7 +3726,7 @@ async def test_finish_reason_propagation(
async def test_finish_reason_unknown_maps_to_other(
mock_acompletion, lite_llm_instance
):
"""Test that unknown finish_reason values map to FinishReason.OTHER."""
"""Test that unmapped finish_reason values map to FinishReason.OTHER."""
mock_response = ModelResponse(
choices=[
Choices(
@@ -3665,7 +3734,9 @@ async def test_finish_reason_unknown_maps_to_other(
role="assistant",
content="Test response",
),
finish_reason="unknown_reason_type",
# LiteLLM validates finish_reason to a known set. Use a value that
# LiteLLM accepts but ADK does not explicitly map.
finish_reason="eos",
)
]
)