diff --git a/contributing/samples/manual_ollama_test/__init__.py b/contributing/samples/manual_ollama_test/__init__.py
new file mode 100644
index 00000000..c48963cd
--- /dev/null
+++ b/contributing/samples/manual_ollama_test/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import agent
diff --git a/contributing/samples/manual_ollama_test/agent.py b/contributing/samples/manual_ollama_test/agent.py
new file mode 100644
index 00000000..e3d071b9
--- /dev/null
+++ b/contributing/samples/manual_ollama_test/agent.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from google.adk.agents.llm_agent import LlmAgent
+from google.adk.agents.sequential_agent import SequentialAgent
+from google.adk.models.lite_llm import LiteLlm
+
+ollama_model = LiteLlm(model="ollama_chat/qwen2.5:7b")
+
+hello_agent = LlmAgent(
+    name="hello_step",
+    instruction="Say hello to the user. Be concise.",
+    model=ollama_model,
+)
+
+summarize_agent = LlmAgent(
+    name="summarize_step",
+    instruction="Summarize the previous assistant message in 5 words.",
+    model=ollama_model,
+)
+
+root_agent = SequentialAgent(
+    name="ollama_seq_test",
+    description="Two-step sanity check for Ollama LiteLLM chat.",
+    sub_agents=[hello_agent, summarize_agent],
+)
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index dfbf15a7..fd74a5e3 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import base64
+import copy
 import json
 import logging
 import os
@@ -559,6 +560,99 @@ async def _get_content(
   return content_objects
 
 
+def _is_ollama_chat_provider(
+    model: Optional[str], custom_llm_provider: Optional[str]
+) -> bool:
+  """Returns True when requests should be normalized for ollama_chat."""
+  if custom_llm_provider and custom_llm_provider.lower() == "ollama_chat":
+    return True
+  if model and model.lower().startswith("ollama_chat"):
+    return True
+  return False
+
+
+def _flatten_ollama_content(
+    content: OpenAIMessageContent | str | None,
+) -> str | OpenAIMessageContent | None:
+  """Flattens multipart content to text for ollama_chat compatibility.
+
+  Ollama's chat endpoint rejects arrays for `content`. We keep textual parts,
+  join them with newlines, and fall back to a JSON string for non-text content.
+  If both text and non-text parts are present, only the text parts are kept.
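+
+  Example (illustrative; image part shape follows OpenAI-style blocks):
+
+  >>> _flatten_ollama_content([
+  ...     {"type": "text", "text": "Describe this image."},
+  ...     {"type": "image_url", "image_url": {"url": "data:image/png"}},
+  ... ])
+  'Describe this image.'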
+  """
+  if not isinstance(content, list):
+    return content
+
+  text_parts = []
+  for block in content:
+    if isinstance(block, dict) and block.get("type") == "text":
+      text_value = block.get("text")
+      if text_value:
+        text_parts.append(text_value)
+
+  if text_parts:
+    return _NEW_LINE.join(text_parts)
+
+  try:
+    return json.dumps(content)
+  except TypeError:
+    return str(content)
+
+
+def _normalize_ollama_chat_messages(
+    messages: list[Message],
+    *,
+    model: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+) -> list[Message]:
+  """Normalizes message payloads for the ollama_chat provider.
+
+  The provider expects string content, so multipart content is flattened to
+  text; messages for other providers are returned unchanged.
+  """
+  if not _is_ollama_chat_provider(model, custom_llm_provider):
+    return messages
+
+  normalized_messages: list[Message] = []
+  for message in messages:
+    if isinstance(message, dict):
+      message_copy = dict(message)
+      message_copy["content"] = _flatten_ollama_content(
+          message_copy.get("content")
+      )
+      normalized_messages.append(message_copy)
+      continue
+
+    message_copy = (
+        message.model_copy()
+        if hasattr(message, "model_copy")
+        else copy.copy(message)
+    )
+    if hasattr(message_copy, "content"):
+      flattened_content = _flatten_ollama_content(
+          getattr(message_copy, "content")
+      )
+      try:
+        setattr(message_copy, "content", flattened_content)
+      except AttributeError as e:
+        logger.debug(
+            "Failed to set 'content' attribute on message of type %s: %s",
+            type(message_copy).__name__,
+            e,
+        )
+    normalized_messages.append(message_copy)
+
+  return normalized_messages
+
+
 def _build_tool_call_from_json_dict(
     candidate: Any, *, index: int
 ) -> Optional[ChatCompletionMessageToolCall]:
@@ -1350,9 +1444,14 @@
     _append_fallback_user_content_if_missing(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    model = llm_request.model or self.model
+    effective_model = llm_request.model or self.model
     messages, tools, response_format, generation_params = (
-        await _get_completion_inputs(llm_request, model)
+        await _get_completion_inputs(llm_request, effective_model)
+    )
+    normalized_messages = _normalize_ollama_chat_messages(
+        messages,
+        model=effective_model,
+        custom_llm_provider=self._additional_args.get("custom_llm_provider"),
     )
 
     if "functions" in self._additional_args:
@@ -1360,8 +1459,8 @@
       tools = None
 
     completion_args = {
-        "model": model,
-        "messages": messages,
+        "model": effective_model,
+        "messages": normalized_messages,
         "tools": tools,
         "response_format": response_format,
     }
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index d3d58d44..1043d415 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+from unittest.mock import ANY
 from unittest.mock import AsyncMock
 from unittest.mock import Mock
 import warnings
@@ -1468,6 +1469,79 @@ async def test_generate_content_async_with_usage_metadata(
   mock_acompletion.assert_called_once()
 
 
+@pytest.mark.asyncio
+async def test_generate_content_async_ollama_chat_flattens_content(
+    mock_acompletion, mock_completion
+):
+  llm_client = MockLLMClient(mock_acompletion, mock_completion)
+  lite_llm_instance = LiteLlm(
+      model="ollama_chat/qwen2.5:7b", llm_client=llm_client
+  )
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Describe this image."),
+                  types.Part.from_bytes(
+                      data=b"test_image", mime_type="image/png"
+                  ),
+              ],
+          )
+      ]
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(llm_request):
+    pass
+
+  mock_acompletion.assert_called_once_with(
+      model="ollama_chat/qwen2.5:7b",
+      messages=ANY,
+      tools=ANY,
+      response_format=ANY,
+  )
+  _, kwargs = mock_acompletion.call_args
+  message_content = kwargs["messages"][0]["content"]
+  assert isinstance(message_content, str)
+  assert "Describe this image." in message_content
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_custom_provider_flattens_content(
+    mock_acompletion, mock_completion
+):
+  llm_client = MockLLMClient(mock_acompletion, mock_completion)
+  lite_llm_instance = LiteLlm(
+      model="qwen2.5:7b",
+      llm_client=llm_client,
+      custom_llm_provider="ollama_chat",
+  )
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user",
+              parts=[
+                  types.Part.from_text(text="Describe this image."),
+                  types.Part.from_bytes(
+                      data=b"test_image", mime_type="image/png"
+                  ),
+              ],
+          )
+      ]
+  )
+
+  async for _ in lite_llm_instance.generate_content_async(llm_request):
+    pass
+
+  mock_acompletion.assert_called_once()
+  _, kwargs = mock_acompletion.call_args
+  assert kwargs["custom_llm_provider"] == "ollama_chat"
+  assert kwargs["model"] == "qwen2.5:7b"
+  message_content = kwargs["messages"][0]["content"]
+  assert isinstance(message_content, str)
+  assert "Describe this image." in message_content
+
+
 @pytest.mark.asyncio
 async def test_content_to_message_param_user_message():
   content = types.Content(