fix: stream in litellm + adk and add corresponding integration tests
Fixes https://github.com/google/adk-python/issues/1368
PiperOrigin-RevId: 772218385
Committed by: Copybara-Service
Parent: 4bda245171
Commit: aafa80bd85
@@ -739,11 +739,12 @@ class LiteLlm(BaseLlm):
               _message_to_generate_content_response(
                   ChatCompletionAssistantMessage(
                       role="assistant",
-                      content="",
+                      content=text,
                       tool_calls=tool_calls,
                   )
               )
           )
+          text = ""
           function_calls.clear()
         elif finish_reason == "stop" and text:
           aggregated_llm_response = _message_to_generate_content_response(
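The change above is easiest to see outside the diff. Below is a minimal, self-contained sketch of the aggregation pattern being fixed; `Chunk` and `aggregate` are hypothetical stand-ins, not ADK or LiteLLM APIs. Before the fix, buffered text was dropped (sent as `content=""`) whenever a streamed chunk finished with tool calls; the fix emits the buffered text alongside the tool calls and then resets the buffer so the text-only branch cannot re-emit it:

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Chunk:
  """Hypothetical stand-in for one streamed model delta."""
  text: str = ""
  tool_calls: List[dict] = field(default_factory=list)
  finish_reason: Optional[str] = None


def aggregate(chunks):
  """Yields (text, tool_calls) messages from a stream of chunks."""
  text = ""
  tool_calls = []
  for chunk in chunks:
    text += chunk.text
    tool_calls.extend(chunk.tool_calls)
    if chunk.finish_reason in ("tool_calls", "stop") and tool_calls:
      # The fix above: flush the buffered text together with the tool
      # calls (previously an empty string), then reset the buffer so the
      # text-only branch below cannot emit the same text twice.
      yield text, list(tool_calls)
      text = ""
      tool_calls.clear()
    elif chunk.finish_reason == "stop" and text:
      yield text, []
      text = ""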
@@ -20,12 +20,26 @@ from google.genai.types import Content
 from google.genai.types import Part
 import pytest

-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

 _SYSTEM_PROMPT = """You are a helpful assistant."""


+def get_weather(city: str) -> str:
+  """Simulates a web search. Use it to get information on weather.
+
+  Args:
+    city: A string containing the location to get weather information for.
+
+  Returns:
+    A string with the simulated weather information for the queried city.
+  """
+  if "sf" in city.lower() or "san francisco" in city.lower():
+    return "It's 70 degrees and foggy."
+  return "It's 80 degrees and sunny."
+
+
 @pytest.fixture
 def oss_llm():
   return LiteLlm(model=_TEST_MODEL_NAME)
@@ -44,6 +58,48 @@ def llm_request():
   )


+@pytest.fixture
+def llm_request_with_tools():
+  return LlmRequest(
+      model=_TEST_MODEL_NAME,
+      contents=[
+          Content(
+              role="user",
+              parts=[
+                  Part.from_text(text="What is the weather in San Francisco?")
+              ],
+          )
+      ],
+      config=types.GenerateContentConfig(
+          temperature=0.1,
+          response_modalities=[types.Modality.TEXT],
+          system_instruction=_SYSTEM_PROMPT,
+          tools=[
+              types.Tool(
+                  function_declarations=[
+                      types.FunctionDeclaration(
+                          name="get_weather",
+                          description="Get the weather in a given location",
+                          parameters=types.Schema(
+                              type=types.Type.OBJECT,
+                              properties={
+                                  "city": types.Schema(
+                                      type=types.Type.STRING,
+                                      description=(
+                                          "The city to get the weather for."
+                                      ),
+                                  ),
+                              },
+                              required=["city"],
+                          ),
+                      )
+                  ]
+              )
+          ],
+      ),
+  )
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async(oss_llm, llm_request):
   async for response in oss_llm.generate_content_async(llm_request):
@@ -51,10 +107,8 @@ async def test_generate_content_async(oss_llm, llm_request):
     assert response.content.parts[0].text


-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
-async def test_generate_content_async_stream(oss_llm, llm_request):
+async def test_generate_content_async(oss_llm, llm_request):
   responses = [
       resp
       async for resp in oss_llm.generate_content_async(
@@ -63,3 +117,50 @@ async def test_generate_content_async_stream(oss_llm, llm_request):
   ]
   part = responses[0].content.parts[0]
   assert len(part.text) > 0
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=False
+      )
+  ]
+  function_call = responses[0].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream(oss_llm, llm_request):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(llm_request, stream=True)
+  ]
+  text = ""
+  for i in range(len(responses) - 1):
+    assert responses[i].partial is True
+    assert responses[i].content.parts[0].text
+    text += responses[i].content.parts[0].text
+
+  # Last message should be accumulated text
+  assert responses[-1].content.parts[0].text == text
+  assert not responses[-1].partial
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_tools(
+    oss_llm, llm_request_with_tools
+):
+  responses = [
+      resp
+      async for resp in oss_llm.generate_content_async(
+          llm_request_with_tools, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
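The streaming test above pins down the consumer-facing contract: every response except the last is a delta marked `partial=True`, and the final non-partial response repeats the full accumulated text. A minimal sketch of how a caller might rely on that contract; `llm` and `request` are assumed to be a LiteLlm instance and an LlmRequest built as in the fixtures, and `print_stream` is a hypothetical helper, not part of the diff:

async def print_stream(llm, request) -> str:
  """Prints deltas as they arrive and returns the final accumulated text."""
  final_text = ""
  async for response in llm.generate_content_async(request, stream=True):
    part = response.content.parts[0]
    if response.partial:
      print(part.text, end="", flush=True)  # incremental delta
    else:
      final_text = part.text  # last response carries the accumulated text
  return final_text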
@@ -13,7 +13,6 @@
 # limitations under the License.

 from google.adk.models import LlmRequest
-from google.adk.models import LlmResponse
 from google.adk.models.lite_llm import LiteLlm
 from google.genai import types
 from google.genai.types import Content
@@ -23,12 +22,11 @@ import pytest

 litellm.add_function_to_prompt = True

-_TEST_MODEL_NAME = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas"
+_TEST_MODEL_NAME = "vertex_ai/meta/llama-3.1-405b-instruct-maas"

 _SYSTEM_PROMPT = """
 You are a helpful assistant, and call tools optionally.
-If call tools, the tool format should be in json, and the tool arguments should be parsed from users inputs.
+If call tools, the tool format should be in json body, and the tool argument values should be parsed from users inputs.
 """

@@ -40,7 +38,7 @@ _FUNCTIONS = [{
         "properties": {
             "city": {
                 "type": "string",
-                "description": "The city, e.g. San Francisco",
+                "description": "The city to get the weather for.",
             },
         },
         "required": ["city"],
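The hunk shows only the `properties` fragment of the function spec. For orientation, a hedged sketch of the full entry shape, assuming the standard OpenAI-style function schema that LiteLLM's `add_function_to_prompt` consumes; only the "city" property and its new description are confirmed by the diff, the surrounding keys are assumptions:

# Hedged reconstruction; "name" and "description" values are assumed.
_FUNCTIONS = [{
    "name": "get_weather",
    "description": "Get the weather in a given location",
    "parameters": {
        "type": "object",
        "properties": {
            "city": {
                "type": "string",
                "description": "The city to get the weather for.",
            },
        },
        "required": ["city"],
    },
}]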
@@ -87,8 +85,6 @@ def llm_request():
   )


-# Note that, this test disabled streaming because streaming is not supported
-# properly in the current test model for now.
 @pytest.mark.asyncio
 async def test_generate_content_asyn_with_function(
     oss_llm_with_function, llm_request
@@ -102,3 +98,18 @@ async def test_generate_content_asyn_with_function(
   function_call = responses[0].content.parts[0].function_call
   assert function_call.name == "get_weather"
   assert function_call.args["city"] == "San Francisco"
+
+
+@pytest.mark.asyncio
+async def test_generate_content_asyn_stream_with_function(
+    oss_llm_with_function, llm_request
+):
+  responses = [
+      resp
+      async for resp in oss_llm_with_function.generate_content_async(
+          llm_request, stream=True
+      )
+  ]
+  function_call = responses[-1].content.parts[0].function_call
+  assert function_call.name == "get_weather"
+  assert function_call.args["city"] == "San Francisco"
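To reproduce these integration tests locally, something like `pytest tests/integration -k litellm` should select them, assuming the adk-python test layout; note that they call a live Vertex AI MaaS endpoint, so LiteLLM needs Vertex AI credentials configured in the environment.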