fix: Normalize multipart content for LiteLLM's ollama_chat provider
LiteLLM's `ollama_chat` provider does not accept array-based content in
messages. This change flattens multipart content by joining text parts or
JSON-serializing non-text parts before sending the request to the LiteLLM
completion API. This ensures compatibility with Ollama's chat endpoint.

Close #3727

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 843382361
committed by Copybara-Service
parent df8684734b
commit 055dfc7974
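As a quick illustration (not part of the diff), the normalization turns an OpenAI-style multipart message into the plain string form that Ollama's chat endpoint accepts; the image block below is a hypothetical stand-in for any non-text part:

# Before: array-based content, which ollama_chat rejects.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}
# After normalization: text parts are kept and joined into a single string.
# {"role": "user", "content": "Describe this image."}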
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
@@ -0,0 +1,39 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.models.lite_llm import LiteLlm
+
+ollama_model = LiteLlm(model="ollama_chat/qwen2.5:7b")
+
+hello_agent = LlmAgent(
+    name="hello_step",
+    instruction="Say hello to the user. Be concise.",
+    model=ollama_model,
+)
+
+summarize_agent = LlmAgent(
+    name="summarize_step",
+    instruction="Summarize the previous assistant message in 5 words.",
+    model=ollama_model,
+)
+
+root_agent = SequentialAgent(
+    name="ollama_seq_test",
+    description="Two-step sanity check for Ollama LiteLLM chat.",
+    sub_agents=[hello_agent, summarize_agent],
+)
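For local verification, a sketch like the following can drive root_agent end to end (assuming an Ollama server is running with qwen2.5:7b pulled, and with root_agent imported from the sample module; InMemoryRunner and the session calls follow adk-python's public API, though exact signatures may differ between versions):

import asyncio

from google.adk.runners import InMemoryRunner
from google.genai import types


async def main():
  # InMemoryRunner wires up in-memory session/artifact services for testing.
  runner = InMemoryRunner(agent=root_agent, app_name="ollama_seq_test")
  session = await runner.session_service.create_session(
      app_name="ollama_seq_test", user_id="user"
  )
  message = types.Content(role="user", parts=[types.Part.from_text(text="Hi!")])
  # Both sub-agents run in sequence; print any text the model emits.
  async for event in runner.run_async(
      user_id="user", session_id=session.id, new_message=message
  ):
    if event.content and event.content.parts and event.content.parts[0].text:
      print(f"[{event.author}] {event.content.parts[0].text}")


asyncio.run(main())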
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import base64
+import copy
 import json
 import logging
 import os
@@ -559,6 +560,91 @@ async def _get_content(
   return content_objects
 
 
+def _is_ollama_chat_provider(
+    model: Optional[str], custom_llm_provider: Optional[str]
+) -> bool:
+  """Returns True when requests should be normalized for ollama_chat."""
+  if custom_llm_provider and custom_llm_provider.lower() == "ollama_chat":
+    return True
+  if model and model.lower().startswith("ollama_chat"):
+    return True
+  return False
+
+
+def _flatten_ollama_content(
+    content: OpenAIMessageContent | str | None,
+) -> str | OpenAIMessageContent | None:
+  """Flattens multipart content to text for ollama_chat compatibility.
+
+  Ollama's chat endpoint rejects arrays for `content`. We keep textual parts,
+  join them with newlines, and fall back to a JSON string for non-text content.
+  If both text and non-text parts are present, only the text parts are kept.
+  """
+  if not isinstance(content, list):
+    return content
+
+  text_parts = []
+  for block in content:
+    if isinstance(block, dict) and block.get("type") == "text":
+      text_value = block.get("text")
+      if text_value:
+        text_parts.append(text_value)
+
+  if text_parts:
+    return _NEW_LINE.join(text_parts)
+
+  try:
+    return json.dumps(content)
+  except TypeError:
+    return str(content)
+
+
+def _normalize_ollama_chat_messages(
+    messages: list[Message],
+    *,
+    model: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+) -> list[Message]:
+  """Normalizes message payloads for the ollama_chat provider.
+
+  The provider expects string content. Convert multipart content to text while
+  leaving other providers untouched.
+  """
+  if not _is_ollama_chat_provider(model, custom_llm_provider):
+    return messages
+
+  normalized_messages: list[Message] = []
+  for message in messages:
+    if isinstance(message, dict):
+      message_copy = dict(message)
+      message_copy["content"] = _flatten_ollama_content(
+          message_copy.get("content")
+      )
+      normalized_messages.append(message_copy)
+      continue
+
+    message_copy = (
+        message.model_copy()
+        if hasattr(message, "model_copy")
+        else copy.copy(message)
+    )
+    if hasattr(message_copy, "content"):
+      flattened_content = _flatten_ollama_content(
+          getattr(message_copy, "content")
+      )
+      try:
+        setattr(message_copy, "content", flattened_content)
+      except AttributeError as e:
+        logger.debug(
+            "Failed to set 'content' attribute on message of type %s: %s",
+            type(message_copy).__name__,
+            e,
+        )
+    normalized_messages.append(message_copy)
+
+  return normalized_messages
+
+
 def _build_tool_call_from_json_dict(
     candidate: Any, *, index: int
 ) -> Optional[ChatCompletionMessageToolCall]:
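To make the contract concrete, here is a small usage sketch (not part of the diff) of the new helper on dict-shaped messages, assuming _NEW_LINE is the newline separator already defined in lite_llm.py:

messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Hello"},
        {"type": "text", "text": "world"},
    ],
}]

# Text parts are joined; the original dict is copied rather than mutated.
normalized = _normalize_ollama_chat_messages(
    messages, model="ollama_chat/qwen2.5:7b"
)
assert normalized[0]["content"] == "Hello\nworld"

# Any other provider gets the messages back unchanged.
assert _normalize_ollama_chat_messages(messages, model="gemini-2.0-flash") is messages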
@@ -1350,9 +1436,14 @@ class LiteLlm(BaseLlm):
     _append_fallback_user_content_if_missing(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    model = llm_request.model or self.model
+    effective_model = llm_request.model or self.model
     messages, tools, response_format, generation_params = (
-        await _get_completion_inputs(llm_request, model)
+        await _get_completion_inputs(llm_request, effective_model)
     )
+    normalized_messages = _normalize_ollama_chat_messages(
+        messages,
+        model=effective_model,
+        custom_llm_provider=self._additional_args.get("custom_llm_provider"),
+    )
 
     if "functions" in self._additional_args:
@@ -1360,8 +1451,8 @@ class LiteLlm(BaseLlm):
       tools = None
 
     completion_args = {
-        "model": model,
-        "messages": messages,
+        "model": effective_model,
+        "messages": normalized_messages,
         "tools": tools,
         "response_format": response_format,
     }
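Because extra keyword arguments to LiteLlm are captured in _additional_args (which the custom_llm_provider lookup above relies on), a provider override also routes through the new normalization even when the model string carries no ollama_chat prefix; for example:

lite_llm = LiteLlm(model="qwen2.5:7b", custom_llm_provider="ollama_chat")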
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+from unittest.mock import ANY
 from unittest.mock import AsyncMock
 from unittest.mock import Mock
 import warnings
@@ -1468,6 +1469,79 @@ async def test_generate_content_async_with_usage_metadata(
   mock_acompletion.assert_called_once()
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_ollama_chat_flattens_content(
+    mock_acompletion, mock_completion
+):
+  llm_client = MockLLMClient(mock_acompletion, mock_completion)
+  lite_llm_instance = LiteLlm(
+      model="ollama_chat/qwen2.5:7b", llm_client=llm_client
+  )
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Describe this image."),
+                  types.Part.from_bytes(
+                      data=b"test_image", mime_type="image/png"
+                  ),
+              ],
+          )
+      ]
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(llm_request):
+    pass
+
+  mock_acompletion.assert_called_once_with(
+      model="ollama_chat/qwen2.5:7b",
+      messages=ANY,
+      tools=ANY,
+      response_format=ANY,
+  )
+  _, kwargs = mock_acompletion.call_args
+  message_content = kwargs["messages"][0]["content"]
+  assert isinstance(message_content, str)
+  assert "Describe this image." in message_content
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_custom_provider_flattens_content(
+    mock_acompletion, mock_completion
+):
+  llm_client = MockLLMClient(mock_acompletion, mock_completion)
+  lite_llm_instance = LiteLlm(
+      model="qwen2.5:7b",
+      llm_client=llm_client,
+      custom_llm_provider="ollama_chat",
+  )
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Describe this image."),
+                  types.Part.from_bytes(
+                      data=b"test_image", mime_type="image/png"
+                  ),
+              ],
+          )
+      ]
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(llm_request):
+    pass
+
+  mock_acompletion.assert_called_once()
+  _, kwargs = mock_acompletion.call_args
+  assert kwargs["custom_llm_provider"] == "ollama_chat"
+  assert kwargs["model"] == "qwen2.5:7b"
+  message_content = kwargs["messages"][0]["content"]
+  assert isinstance(message_content, str)
+  assert "Describe this image." in message_content
+
+
 @pytest.mark.asyncio
 async def test_content_to_message_param_user_message():
   content = types.Content(