From aafa80bd85a49fb1c1a255ac797587cffd3fa567 Mon Sep 17 00:00:00 2001
From: Google Team Member
Date: Mon, 16 Jun 2025 16:36:34 -0700
Subject: [PATCH] fix: stream in litellm + adk and add corresponding
 integration tests

Fixes https://github.com/google/adk-python/issues/1368

PiperOrigin-RevId: 772218385
---
 src/google/adk/models/lite_llm.py              |   3 +-
 .../models/test_litellm_no_function.py         | 109 +++++++++++++++++-
 .../models/test_litellm_with_function.py       |  25 ++--
 3 files changed, 125 insertions(+), 12 deletions(-)

diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index ed54faec..dce5ed7c 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -739,11 +739,12 @@ class LiteLlm(BaseLlm):
               _message_to_generate_content_response(
                   ChatCompletionAssistantMessage(
                       role="assistant",
-                      content="",
+                      content=text,
                       tool_calls=tool_calls,
                   )
               )
           )
+          text = ""
           function_calls.clear()
         elif finish_reason == "stop" and text:
           aggregated_llm_response = _message_to_generate_content_response(
diff --git a/tests/integration/models/test_litellm_no_function.py b/tests/integration/models/test_litellm_no_function.py
index e662384c..ff5d3bb8 100644
--- a/tests/integration/models/test_litellm_no_function.py
+++ b/tests/integration/models/test_litellm_no_function.py
@@ -20,12 +20,26 @@
 from google.genai.types import Content
 from google.genai.types import Part
 import pytest
 
-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"
 
 _SYSTEM_PROMPT = """You are a helpful assistant."""
 
 
+def get_weather(city: str) -> str:
+  """Simulates a web search. Use it to get information on the weather.
+
+  Args:
+    city: A string containing the location to get weather information for.
+
+  Returns:
+    A string with the simulated weather information for the queried city.
+  """
+  if "sf" in city.lower() or "san francisco" in city.lower():
+    return "It's 70 degrees and foggy."
+  return "It's 80 degrees and sunny."
+
+
 @pytest.fixture
 def oss_llm():
   return LiteLlm(model=_TEST_MODEL_NAME)
@@ -44,6 +58,48 @@ def llm_request():
   )
 
 
+@pytest.fixture
+def llm_request_with_tools():
+  return LlmRequest(
+      model=_TEST_MODEL_NAME,
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part.from_text(text="What is the weather in San Francisco?")
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction=_SYSTEM_PROMPT,
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="get_weather",
+                          description="Get the weather in a given location",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "city": types.Schema(
+                                      type=types.Type.STRING,
+                                      description=(
+                                          "The city to get the weather for."
+                                      ),
+                                  ),
+                              },
+                              required=["city"],
+                          ),
+                      )
+                  ]
+              )
+          ],
+      ),
+  )
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async(oss_llm, llm_request):
   async for response in oss_llm.generate_content_async(llm_request):
@@ -51,10 +107,8 @@ async def test_generate_content_async(oss_llm, llm_request):
     assert response.content.parts[0].text
 
 
-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
-async def test_generate_content_async_stream(oss_llm, llm_request):
+async def test_generate_content_async_non_stream(oss_llm, llm_request):
   responses = [
       resp
       async for resp in oss_llm.generate_content_async(
@@ -63,3 +117,50 @@
   ]
   part = responses[0].content.parts[0]
   assert len(part.text) > 0
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=False
+      )
+  ]
+  function_call = responses[0].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(oss_llm, llm_request):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(llm_request, stream=True)
+  ]
+  text = ""
+  for i in range(len(responses) - 1):
+    assert responses[i].partial is True
+    assert responses[i].content.parts[0].text
+    text += responses[i].content.parts[0].text
+
+  # The final response should carry the full accumulated text.
+  assert responses[-1].content.parts[0].text == text
+  assert not responses[-1].partial
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
diff --git a/tests/integration/models/test_litellm_with_function.py b/tests/integration/models/test_litellm_with_function.py
index a2ceb540..799c55e5 100644
--- a/tests/integration/models/test_litellm_with_function.py
+++ b/tests/integration/models/test_litellm_with_function.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from google.adk.models import LlmRequest
-from google.adk.models import LlmResponse
 from google.adk.models.lite_llm import LiteLlm
 from google.genai import types
 from google.genai.types import Content
@@ -23,12 +22,11 @@
 import pytest
 
 litellm.add_function_to_prompt = True
 
-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
-
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"
 _SYSTEM_PROMPT = """
 You are a helpful assistant, and call tools optionally.
-If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs.
+If you call tools, format each tool call as a JSON body, and parse the tool argument values from the user's inputs.
 """
 
 
@@ -40,7 +38,7 @@ _FUNCTIONS = [{
         "properties": {
             "city": {
                 "type": "string",
-                "description": "The city, e.g. San Francisco",
+                "description": "The city to get the weather for.",
             },
         },
         "required": ["city"],
@@ -87,8 +85,6 @@ def llm_request():
   )
 
 
-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
 async def test_generate_content_asyn_with_function(
     oss_llm_with_function, llm_request
@@ -102,3 +98,18 @@
   function_call = responses[0].content.parts[0].function_call
   assert function_call.name == "get_weather"
   assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_asyn_stream_with_function(
+    oss_llm_with_function, llm_request
+):
+  responses = [
+      resp
+      async for resp in oss_llm_with_function.generate_content_async(
+          llm_request, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
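
Reviewer note: the one-line lite_llm.py change is easiest to see in isolation,
so here is a minimal self-contained sketch of the streaming aggregation loop
it fixes, assuming simplified chunk and response shapes. `Chunk` and
`aggregate` are hypothetical stand-ins for illustration, not ADK or LiteLLM
APIs. Before the fix, the tool-call response was emitted with content="" and
the text buffer was never reset, so assistant text streamed ahead of a tool
call was omitted from the tool-call response and then bled into the next
"stop" response.

from dataclasses import dataclass, field


@dataclass
class Chunk:
  """A simplified streamed chunk: a text delta, tool calls, or a finish."""

  text: str = ""
  tool_calls: list = field(default_factory=list)
  finish_reason: str = ""


def aggregate(chunks):
  """Folds streamed chunks into responses, mirroring the patched loop."""
  text = ""
  function_calls = []
  responses = []
  for chunk in chunks:
    text += chunk.text
    function_calls.extend(chunk.tool_calls)
    if chunk.finish_reason == "tool_calls" and function_calls:
      # The fix: attach the text accumulated so far to the tool-call
      # response (it was previously hard-coded to ""), then reset the
      # buffer so the same text is not re-emitted on a later "stop".
      responses.append({"content": text, "tool_calls": list(function_calls)})
      text = ""
      function_calls.clear()
    elif chunk.finish_reason == "stop" and text:
      responses.append({"content": text})
  return responses


# With the reset in place, text streamed before a tool call is reported once,
# on the tool-call response, and a later "stop" carries only its own text.
assert aggregate([
    Chunk(text="Let me check. "),
    Chunk(tool_calls=["get_weather"], finish_reason="tool_calls"),
    Chunk(text="It's 70 degrees and foggy."),
    Chunk(finish_reason="stop"),
]) == [
    {"content": "Let me check. ", "tool_calls": ["get_weather"]},
    {"content": "It's 70 degrees and foggy."},
]

Resetting text alongside function_calls.clear() keeps both accumulators scoped
to a single logical response, which is what the new
test_generate_content_async_stream asserts via the partial flags and the final
accumulated text.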