fix: streaming in litellm + adk and add corresponding integration tests

Fixes https://github.com/google/adk-python/issues/1368

PiperOrigin-RevId: 772218385
Authored by Google Team Member on 2025-06-16 16:36:34 -07:00
Committed by Copybara-Service
parent 4bda245171
commit aafa80bd85
3 changed files with 125 additions and 12 deletions
src/google/adk/models/lite_llm.py  +2 -1
@@ -739,11 +739,12 @@ class LiteLlm(BaseLlm):
           _message_to_generate_content_response(
               ChatCompletionAssistantMessage(
                   role="assistant",
-                  content="",
+                  content=text,
                   tool_calls=tool_calls,
               )
           )
       )
+      text = ""
       function_calls.clear()
     elif finish_reason == "stop" and text:
       aggregated_llm_response = _message_to_generate_content_response(
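
This hunk is the heart of the fix for issue #1368: when a streamed completion run ended with tool calls, the aggregated assistant message was emitted with content="", dropping any text streamed before the tool calls, and the text buffer was never cleared afterward, so a later "stop" chunk could emit it again. Below is a minimal, self-contained sketch of the accumulate-and-flush pattern the fix restores; the names (Chunk, aggregate) are invented for illustration and are not the ADK or LiteLLM API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
  # One streamed delta: some text, possibly a tool-call fragment, and the
  # finish reason carried by the final delta of a run.
  text: str = ""
  tool_call: Optional[str] = None
  finish_reason: Optional[str] = None


def aggregate(chunks):
  text = ""
  tool_calls = []
  responses = []
  for chunk in chunks:
    text += chunk.text
    if chunk.tool_call:
      tool_calls.append(chunk.tool_call)
    if chunk.finish_reason == "tool_calls" and tool_calls:
      # The bug: emitting content="" here dropped any text streamed before
      # the tool calls. The fix emits the buffered text instead ...
      responses.append({"content": text, "tool_calls": list(tool_calls)})
      text = ""  # ... and clears the buffer (the added `text = ""` line).
      tool_calls.clear()
    elif chunk.finish_reason == "stop" and text:
      responses.append({"content": text, "tool_calls": []})
  return responses
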
@@ -20,12 +20,26 @@ from google.genai.types import Content
 from google.genai.types import Part
 import pytest

-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

 _SYSTEM_PROMPT = """You are a helpful assistant."""


+def get_weather(city: str) -> str:
+  """Simulates a web search. Use it to get information on weather.
+
+  Args:
+    city: A string containing the location to get weather information for.
+
+  Returns:
+    A string with the simulated weather information for the queried city.
+  """
+  if "sf" in city.lower() or "san francisco" in city.lower():
+    return "It's 70 degrees and foggy."
+  return "It's 80 degrees and sunny."
+
+
 @pytest.fixture
 def oss_llm():
   return LiteLlm(model=_TEST_MODEL_NAME)
@@ -44,6 +58,48 @@ def llm_request():
   )


+@pytest.fixture
+def llm_request_with_tools():
+  return LlmRequest(
+      model=_TEST_MODEL_NAME,
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part.from_text(text="What is the weather in San Francisco?")
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction=_SYSTEM_PROMPT,
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="get_weather",
+                          description="Get the weather in a given location",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "city": types.Schema(
+                                      type=types.Type.STRING,
+                                      description=(
+                                          "The city to get the weather for."
+                                      ),
+                                  ),
+                              },
+                              required=["city"],
+                          ),
+                      )
+                  ]
+              )
+          ],
+      ),
+  )
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async(oss_llm, llm_request):
   async for response in oss_llm.generate_content_async(llm_request):
@@ -51,10 +107,8 @@ async def test_generate_content_async(oss_llm, llm_request):
     assert response.content.parts[0].text


-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
-async def test_generate_content_async_stream(oss_llm, llm_request):
+async def test_generate_content_async(oss_llm, llm_request):
   responses = [
       resp
       async for resp in oss_llm.generate_content_async(
@@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request):
   ]
   part = responses[0].content.parts[0]
   assert len(part.text) > 0


+@pytest.mark.asyncio
+async def test_generate_content_async_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=False
+      )
+  ]
+  function_call = responses[0].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(oss_llm, llm_request):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(llm_request, stream=True)
+  ]
+  text = ""
+  for i in range(len(responses) - 1):
+    assert responses[i].partial is True
+    assert responses[i].content.parts[0].text
+    text += responses[i].content.parts[0].text
+
+  # The last message should carry the accumulated text.
+  assert responses[-1].content.parts[0].text == text
+  assert not responses[-1].partial
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
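
These three tests pin down the streaming contract the fix establishes: every response before the last carries partial=True with a text delta, and the final response carries the full accumulated text (or the aggregated function call) with partial unset. A hedged usage sketch, not part of the commit, of a caller consuming that contract (print_stream is an invented helper):

async def print_stream(llm, llm_request):
  # Print text deltas as they arrive; return the final aggregated response.
  final = None
  async for response in llm.generate_content_async(llm_request, stream=True):
    if response.partial:
      print(response.content.parts[0].text, end="", flush=True)
    else:
      final = response
  return final
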
@@ -13,7 +13,6 @@
 # limitations under the License.

 from google.adk.models import LlmRequest
-from google.adk.models import LlmResponse
 from google.adk.models.lite_llm import LiteLlm
 from google.genai import types
 from google.genai.types import Content
@@ -23,12 +22,11 @@ import pytest
 litellm.add_function_to_prompt = True

-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

 _SYSTEM_PROMPT = """
 You are a helpful assistant, and call tools optionally.
-If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs.
+If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs.
 """
@@ -40,7 +38,7 @@ _FUNCTIONS = [{
"properties": {
"city": {
"type": "string",
"description": "The city, e.g. San Francisco",
"description": "The city to get the weather for.",
},
},
"required": ["city"],
@@ -87,8 +85,6 @@ def llm_request():
   )


-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
 async def test_generate_content_asyn_with_function(
     oss_llm_with_function, llm_request
@@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function(
   function_call = responses[0].content.parts[0].function_call
   assert function_call.name == "get_weather"
   assert function_call.args["city"] == "San Francisco"


+@pytest.mark.asyncio
+async def test_generate_content_asyn_stream_with_function(
+    oss_llm_with_function, llm_request
+):
+  responses = [
+      resp
+      async for resp in oss_llm_with_function.generate_content_async(
+          llm_request, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"