fix: Normalize multipart content for LiteLLM's ollama_chat provider

LiteLLM's `ollama_chat` provider does not accept array-based content in messages. This change flattens multipart content by joining text parts or JSON-serializing non-text parts before sending the request to the LiteLLM completion API. This ensures compatibility with Ollama's chat endpoint.
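
As an illustration of the change (the part payloads below are made up for this
description, not taken from the diff), a user turn that would otherwise be
forwarded as an array is reduced to a plain string before the request goes out:

# Before: array-based content, which Ollama's chat endpoint rejects.
before = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}

# After normalization: text parts are joined with newlines; if no text parts
# exist, the original content array is JSON-serialized instead.
after = {"role": "user", "content": "Describe this image."}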

Close #3727

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 843382361
Authored by George Weale on 2025-12-11 14:46:08 -08:00
Committed by Copybara-Service
parent df8684734b
commit 055dfc7974
4 changed files with 223 additions and 4 deletions
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,39 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.lite_llm import LiteLlm

ollama_model = LiteLlm(model="ollama_chat/qwen2.5:7b")

hello_agent = LlmAgent(
    name="hello_step",
    instruction="Say hello to the user. Be concise.",
    model=ollama_model,
)

summarize_agent = LlmAgent(
    name="summarize_step",
    instruction="Summarize the previous assistant message in 5 words.",
    model=ollama_model,
)

root_agent = SequentialAgent(
    name="ollama_seq_test",
    description="Two-step sanity check for Ollama LiteLLM chat.",
    sub_agents=[hello_agent, summarize_agent],
)
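
Running the sample assumes a local Ollama server. A minimal setup sketch (not
part of this change; it relies on LiteLLM's standard Ollama configuration, so
treat the endpoint variable below as an assumption):

# Hypothetical local setup: LiteLLM resolves the Ollama endpoint from
# OLLAMA_API_BASE, so exporting it before `adk run` / `adk web` is typically
# all this sample needs. The URL below is Ollama's default port.
import os

os.environ.setdefault("OLLAMA_API_BASE", "http://localhost:11434")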
+95 -4
@@ -15,6 +15,7 @@
from __future__ import annotations
import base64
import copy
import json
import logging
import os
@@ -559,6 +560,91 @@ async def _get_content(
  return content_objects


def _is_ollama_chat_provider(
    model: Optional[str], custom_llm_provider: Optional[str]
) -> bool:
  """Returns True when requests should be normalized for ollama_chat."""
  if custom_llm_provider and custom_llm_provider.lower() == "ollama_chat":
    return True
  if model and model.lower().startswith("ollama_chat"):
    return True
  return False


def _flatten_ollama_content(
    content: OpenAIMessageContent | str | None,
) -> str | OpenAIMessageContent | None:
  """Flattens multipart content to text for ollama_chat compatibility.

  Ollama's chat endpoint rejects arrays for `content`. We keep textual parts,
  join them with newlines, and fall back to a JSON string for non-text content.
  If both text and non-text parts are present, only the text parts are kept.
  """
  if not isinstance(content, list):
    return content

  text_parts = []
  for block in content:
    if isinstance(block, dict) and block.get("type") == "text":
      text_value = block.get("text")
      if text_value:
        text_parts.append(text_value)

  if text_parts:
    return _NEW_LINE.join(text_parts)

  try:
    return json.dumps(content)
  except TypeError:
    return str(content)


def _normalize_ollama_chat_messages(
    messages: list[Message],
    *,
    model: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
) -> list[Message]:
  """Normalizes message payloads for ollama_chat provider.

  The provider expects string content. Convert multipart content to text while
  leaving other providers untouched.
  """
  if not _is_ollama_chat_provider(model, custom_llm_provider):
    return messages

  normalized_messages: list[Message] = []
  for message in messages:
    if isinstance(message, dict):
      message_copy = dict(message)
      message_copy["content"] = _flatten_ollama_content(
          message_copy.get("content")
      )
      normalized_messages.append(message_copy)
      continue

    message_copy = (
        message.model_copy()
        if hasattr(message, "model_copy")
        else copy.copy(message)
    )
    if hasattr(message_copy, "content"):
      flattened_content = _flatten_ollama_content(
          getattr(message_copy, "content")
      )
      try:
        setattr(message_copy, "content", flattened_content)
      except AttributeError as e:
        logger.debug(
            "Failed to set 'content' attribute on message of type %s: %s",
            type(message_copy).__name__,
            e,
        )
    normalized_messages.append(message_copy)
  return normalized_messages


def _build_tool_call_from_json_dict(
    candidate: Any, *, index: int
) -> Optional[ChatCompletionMessageToolCall]:
@@ -1350,9 +1436,14 @@ class LiteLlm(BaseLlm):
    _append_fallback_user_content_if_missing(llm_request)
    logger.debug(_build_request_log(llm_request))

-   model = llm_request.model or self.model
+   effective_model = llm_request.model or self.model
    messages, tools, response_format, generation_params = (
-       await _get_completion_inputs(llm_request, model)
+       await _get_completion_inputs(llm_request, effective_model)
    )
+   normalized_messages = _normalize_ollama_chat_messages(
+       messages,
+       model=effective_model,
+       custom_llm_provider=self._additional_args.get("custom_llm_provider"),
+   )

    if "functions" in self._additional_args:
@@ -1360,8 +1451,8 @@ class LiteLlm(BaseLlm):
      tools = None

    completion_args = {
-       "model": model,
-       "messages": messages,
+       "model": effective_model,
+       "messages": normalized_messages,
        "tools": tools,
        "response_format": response_format,
    }
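
Read together with the docstrings above, the flattening contract can be
sketched as follows (illustrative calls, assuming they run with the helpers
imported from the module; the content blocks are not values from this change):

# Text blocks are kept and joined with the module's _NEW_LINE constant.
_flatten_ollama_content([
    {"type": "text", "text": "first"},
    {"type": "text", "text": "second"},
])  # -> "first\nsecond"

# With no usable text blocks, the whole array is JSON-serialized as a fallback.
_flatten_ollama_content([
    {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
])  # -> '[{"type": "image_url", ...}]'

# Non-list content (plain strings, None) passes through unchanged, and
# _normalize_ollama_chat_messages applies the flattening per message only when
# _is_ollama_chat_provider() matches the model prefix or custom provider.
_flatten_ollama_content("Describe this image.")  # -> "Describe this image."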
+74
@@ -13,6 +13,7 @@
# limitations under the License.
import json
from unittest.mock import ANY
from unittest.mock import AsyncMock
from unittest.mock import Mock
import warnings
@@ -1468,6 +1469,79 @@ async def test_generate_content_async_with_usage_metadata(
  mock_acompletion.assert_called_once()


@pytest.mark.asyncio
async def test_generate_content_async_ollama_chat_flattens_content(
    mock_acompletion, mock_completion
):
  llm_client = MockLLMClient(mock_acompletion, mock_completion)
  lite_llm_instance = LiteLlm(
      model="ollama_chat/qwen2.5:7b", llm_client=llm_client
  )
  llm_request = LlmRequest(
      contents=[
          types.Content(
              role="user",
              parts=[
                  types.Part.from_text(text="Describe this image."),
                  types.Part.from_bytes(
                      data=b"test_image", mime_type="image/png"
                  ),
              ],
          )
      ]
  )

  async for _ in lite_llm_instance.generate_content_async(llm_request):
    pass

  mock_acompletion.assert_called_once_with(
      model="ollama_chat/qwen2.5:7b",
      messages=ANY,
      tools=ANY,
      response_format=ANY,
  )
  _, kwargs = mock_acompletion.call_args
  message_content = kwargs["messages"][0]["content"]
  assert isinstance(message_content, str)
  assert "Describe this image." in message_content


@pytest.mark.asyncio
async def test_generate_content_async_custom_provider_flattens_content(
    mock_acompletion, mock_completion
):
  llm_client = MockLLMClient(mock_acompletion, mock_completion)
  lite_llm_instance = LiteLlm(
      model="qwen2.5:7b",
      llm_client=llm_client,
      custom_llm_provider="ollama_chat",
  )
  llm_request = LlmRequest(
      contents=[
          types.Content(
              role="user",
              parts=[
                  types.Part.from_text(text="Describe this image."),
                  types.Part.from_bytes(
                      data=b"test_image", mime_type="image/png"
                  ),
              ],
          )
      ]
  )

  async for _ in lite_llm_instance.generate_content_async(llm_request):
    pass

  mock_acompletion.assert_called_once()
  _, kwargs = mock_acompletion.call_args
  assert kwargs["custom_llm_provider"] == "ollama_chat"
  assert kwargs["model"] == "qwen2.5:7b"
  message_content = kwargs["messages"][0]["content"]
  assert isinstance(message_content, str)
  assert "Describe this image." in message_content


@pytest.mark.asyncio
async def test_content_to_message_param_user_message():
  content = types.Content(