feat: Add support for parsing inline JSON tool calls in LiteLLM responses
Close #1968

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 831911719
committed by Copybara-Service
parent 2efc184a46
commit 22eb7e5b06
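
Background for this change: some OpenAI-compatible backends (self-hosted vLLM models, for example) emit tool invocations as JSON embedded directly in the assistant's text content rather than in the structured tool_calls field. The helper added in this commit splits such text into tool calls plus leftover prose. A minimal sketch of the intended behavior (the content string and call id here are illustrative, not taken from the tests):

    from google.adk.models.lite_llm import _parse_tool_calls_from_text

    content = (
        '{"id":"call_1","name":"get_current_time",'
        '"arguments":{"timezone_str":"Asia/Taipei"}} checking the clock now'
    )
    tool_calls, remainder = _parse_tool_calls_from_text(content)
    # tool_calls[0].function.name == "get_current_time"
    # json-decoded arguments == {"timezone_str": "Asia/Taipei"}
    # remainder == "checking the clock now"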
@@ -0,0 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from . import agent
@@ -0,0 +1,174 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import datetime
+import json
+import re
+from typing import Any
+from zoneinfo import ZoneInfo
+from zoneinfo import ZoneInfoNotFoundError
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.lite_llm import LiteLlm
+from google.adk.models.lite_llm import LiteLLMClient
+
+
+class InlineJsonToolClient(LiteLLMClient):
+  """LiteLLM client that emits inline JSON tool calls for testing."""
+
+  async def acompletion(self, model, messages, tools, **kwargs):
+    del tools, kwargs  # Only needed for API parity.
+
+    tool_message = _find_last_role(messages, role="tool")
+    if tool_message:
+      tool_summary = _coerce_to_text(tool_message.get("content"))
+      return {
+          "id": "mock-inline-tool-final-response",
+          "model": model,
+          "choices": [{
+              "message": {
+                  "role": "assistant",
+                  "content": (
+                      f"The instrumentation tool responded with: {tool_summary}"
+                  ),
+              },
+              "finish_reason": "stop",
+          }],
+          "usage": {
+              "prompt_tokens": 60,
+              "completion_tokens": 12,
+              "total_tokens": 72,
+          },
+      }
+
+    timezone = _extract_timezone(messages) or "Asia/Taipei"
+    inline_call = json.dumps(
+        {
+            "name": "get_current_time",
+            "arguments": {"timezone_str": timezone},
+        },
+        separators=(",", ":"),
+    )
+
+    return {
+        "id": "mock-inline-tool-call",
+        "model": model,
+        "choices": [{
+            "message": {
+                "role": "assistant",
+                "content": (
+                    f"{inline_call}\nLet me double-check the clock for you."
+                ),
+            },
+            "finish_reason": "tool_calls",
+        }],
+        "usage": {
+            "prompt_tokens": 45,
+            "completion_tokens": 15,
+            "total_tokens": 60,
+        },
+    }
+
+
+def _find_last_role(
+    messages: list[dict[str, Any]], role: str
+) -> dict[str, Any]:
+  """Returns the last message with the given role."""
+  for message in reversed(messages):
+    if message.get("role") == role:
+      return message
+  return {}
+
+
+def _coerce_to_text(content: Any) -> str:
+  """Best-effort conversion from OpenAI message content to text."""
+  if isinstance(content, str):
+    return content
+  if isinstance(content, dict):
+    return _coerce_to_text(content.get("text"))
+  if isinstance(content, list):
+    texts = []
+    for part in content:
+      if isinstance(part, dict):
+        texts.append(part.get("text") or "")
+      elif isinstance(part, str):
+        texts.append(part)
+    return " ".join(text for text in texts if text)
+  return ""
+
+
+_TIMEZONE_PATTERN = re.compile(r"([A-Za-z]+/[A-Za-z_]+)")
+
+
+def _extract_timezone(messages: list[dict[str, Any]]) -> str | None:
+  """Extracts an IANA timezone string from the last user message."""
+  user_message = _find_last_role(messages, role="user")
+  text = _coerce_to_text(user_message.get("content"))
+  if not text:
+    return None
+  match = _TIMEZONE_PATTERN.search(text)
+  if match:
+    return match.group(1)
+  lowered = text.lower()
+  if "taipei" in lowered:
+    return "Asia/Taipei"
+  if "new york" in lowered:
+    return "America/New_York"
+  if "london" in lowered:
+    return "Europe/London"
+  if "tokyo" in lowered:
+    return "Asia/Tokyo"
+  return None
+
+
+def get_current_time(timezone_str: str) -> dict[str, str]:
+  """Returns mock current time for the provided timezone."""
+  try:
+    tz = ZoneInfo(timezone_str)
+  except ZoneInfoNotFoundError as exc:
+    return {
+        "status": "error",
+        "report": f"Unable to parse timezone '{timezone_str}': {exc}",
+    }
+  now = datetime.datetime.now(tz)
+  return {
+      "status": "success",
+      "report": (
+          f"The current time in {timezone_str} is"
+          f" {now.strftime('%Y-%m-%d %H:%M:%S %Z')}."
+      ),
+  }
+
+
+_mock_model = LiteLlm(
+    model="mock/inline-json-tool-calls",
+    llm_client=InlineJsonToolClient(),
+)
+
+root_agent = Agent(
+    name="litellm_inline_tool_tester",
+    model=_mock_model,
+    description=(
+        "Demonstrates LiteLLM inline JSON tool-call parsing without an external"
+        " VLLM deployment."
+    ),
+    instruction=(
+        "You are a deterministic clock assistant. Always call the"
+        " get_current_time tool before answering user questions. After the tool"
+        " responds, summarize what it returned."
+    ),
+    tools=[get_current_time],
+)
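
For reference, the first-turn assistant content the mock client above produces looks like the following (a quick standalone snippet, not part of the sample itself):

    import json

    inline_call = json.dumps(
        {"name": "get_current_time", "arguments": {"timezone_str": "Asia/Taipei"}},
        separators=(",", ":"),
    )
    print(f"{inline_call}\nLet me double-check the clock for you.")
    # {"name":"get_current_time","arguments":{"timezone_str":"Asia/Taipei"}}
    # Let me double-check the clock for you.

With the usual sample layout, this agent should be runnable via the ADK CLI (e.g. `adk run <path-to-this-sample>`); the exact path is not shown in this diff.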
@@ -31,6 +31,7 @@ from typing import Optional
 from typing import Tuple
 from typing import TypedDict
 from typing import Union
+import uuid
 import warnings

 from google.genai import types
@@ -64,6 +65,7 @@ logger = logging.getLogger("google_adk." + __name__)
 _NEW_LINE = "\n"
 _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
 _LITELLM_STRUCTURED_TYPES = {"json_object", "json_schema"}
+_JSON_DECODER = json.JSONDecoder()

 # Mapping of LiteLLM finish_reason strings to FinishReason enum values
 # Note: tool_calls/function_call map to STOP because:
@@ -431,6 +433,118 @@ def _get_content(
   return content_objects


+def _build_tool_call_from_json_dict(
+    candidate: Any, *, index: int
+) -> Optional[ChatCompletionMessageToolCall]:
+  """Creates a tool call object from JSON content embedded in text."""
+
+  if not isinstance(candidate, dict):
+    return None
+
+  name = candidate.get("name")
+  args = candidate.get("arguments")
+  if not isinstance(name, str) or args is None:
+    return None
+
+  if isinstance(args, str):
+    arguments_payload = args
+  else:
+    try:
+      arguments_payload = json.dumps(args, ensure_ascii=False)
+    except (TypeError, ValueError):
+      arguments_payload = _safe_json_serialize(args)
+
+  call_id = candidate.get("id") or f"adk_tool_call_{uuid.uuid4().hex}"
+  call_index = candidate.get("index")
+  if isinstance(call_index, int):
+    index = call_index
+
+  function = Function(
+      name=name,
+      arguments=arguments_payload,
+  )
+  # Some LiteLLM types carry an `index` field only in streaming contexts,
+  # so guard the assignment to stay compatible with older versions.
+  if hasattr(function, "index"):
+    function.index = index  # type: ignore[attr-defined]
+
+  tool_call = ChatCompletionMessageToolCall(
+      type="function",
+      id=str(call_id),
+      function=function,
+  )
+  # Same reasoning as above: not every ChatCompletionMessageToolCall exposes it.
+  if hasattr(tool_call, "index"):
+    tool_call.index = index  # type: ignore[attr-defined]
+
+  return tool_call
+
+
+def _parse_tool_calls_from_text(
+    text_block: str,
+) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
+  """Extracts inline JSON tool calls from LiteLLM text responses."""
+
+  tool_calls = []
+  if not text_block:
+    return tool_calls, None
+
+  remainder_segments = []
+  cursor = 0
+  text_length = len(text_block)
+
+  while cursor < text_length:
+    brace_index = text_block.find("{", cursor)
+    if brace_index == -1:
+      remainder_segments.append(text_block[cursor:])
+      break
+
+    remainder_segments.append(text_block[cursor:brace_index])
+    try:
+      candidate, end = _JSON_DECODER.raw_decode(text_block, brace_index)
+    except json.JSONDecodeError:
+      remainder_segments.append(text_block[brace_index])
+      cursor = brace_index + 1
+      continue
+
+    tool_call = _build_tool_call_from_json_dict(
+        candidate, index=len(tool_calls)
+    )
+    if tool_call:
+      tool_calls.append(tool_call)
+    else:
+      remainder_segments.append(text_block[brace_index:end])
+    cursor = end
+
+  segments = [segment.strip() for segment in remainder_segments]
+  remainder = " ".join(segment for segment in segments if segment)
+
+  return tool_calls, remainder or None
+
+
+def _split_message_content_and_tool_calls(
+    message: Message,
+) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]:
+  """Returns message content and tool calls, parsing inline JSON when needed."""
+
+  existing_tool_calls = message.get("tool_calls") or []
+  normalized_tool_calls = (
+      list(existing_tool_calls) if existing_tool_calls else []
+  )
+  content = message.get("content")
+
+  # LiteLLM responses either provide structured tool_calls or inline JSON, not
+  # both. When tool_calls are present we trust them and skip the fallback parser.
+  if normalized_tool_calls or not isinstance(content, str):
+    return content, normalized_tool_calls
+
+  fallback_tool_calls, remainder = _parse_tool_calls_from_text(content)
+  if fallback_tool_calls:
+    return remainder, fallback_tool_calls
+
+  return content, []
+
+
 def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
   """Converts a types.Content role to a litellm role.
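
The scanner in _parse_tool_calls_from_text leans on the stdlib's json.JSONDecoder.raw_decode, which parses exactly one JSON value starting at a given index and reports where it ended, so the surrounding prose survives. A standalone illustration (the input string here is made up):

    import json

    decoder = json.JSONDecoder()
    text = 'Intro {"name":"alpha","arguments":{"value":1}} trailing content'
    start = text.find("{")
    candidate, end = decoder.raw_decode(text, start)
    print(candidate)   # {'name': 'alpha', 'arguments': {'value': 1}}
    print(text[end:])  # ' trailing content'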
@@ -584,15 +698,24 @@ def _model_response_to_chunk(
   if message is None and response["choices"][0].get("delta", None):
     message = response["choices"][0]["delta"]

-  if message.get("content", None):
-    yield TextChunk(text=message.get("content")), finish_reason
-
-  if message.get("tool_calls", None):
-    for tool_call in message.get("tool_calls"):
+  message_content: Optional[OpenAIMessageContent] = None
+  tool_calls: list[ChatCompletionMessageToolCall] = []
+  if message is not None:
+    (
+        message_content,
+        tool_calls,
+    ) = _split_message_content_and_tool_calls(message)
+
+  if message_content:
+    yield TextChunk(text=message_content), finish_reason
+
+  if tool_calls:
+    for idx, tool_call in enumerate(tool_calls):
       # aggregate tool_call
       if tool_call.type == "function":
         func_name = tool_call.function.name
         func_args = tool_call.function.arguments
+        func_index = getattr(tool_call, "index", idx)

         # Ignore empty chunks that don't carry any information.
         if not func_name and not func_args:
@@ -602,12 +725,10 @@ def _model_response_to_chunk(
             id=tool_call.id,
             name=func_name,
             args=func_args,
-            index=tool_call.index,
+            index=func_index,
         ), finish_reason

-  if finish_reason and not (
-      message.get("content", None) or message.get("tool_calls", None)
-  ):
+  if finish_reason and not (message_content or tool_calls):
     yield None, finish_reason

   if not message:
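
Net effect on streaming: for an assistant message whose content mixes prose with inline JSON, the generator now yields one TextChunk for the leftover prose, then one FunctionChunk per parsed call, then the usual UsageMetadataChunk. The concrete values below mirror the new parametrized test case further down:

    # content: 'Intro {"id":"call_2","name":"alpha","arguments":{"foo":"bar"}} wrap'
    # yields, in order:
    #   TextChunk(text="Intro wrap")                   with finish_reason "tool_calls"
    #   FunctionChunk(id="call_2", name="alpha",
    #                 args='{"foo": "bar"}', index=0)  with finish_reason "tool_calls"
    #   UsageMetadataChunk(prompt_tokens=11,
    #                      completion_tokens=13, total_tokens=24)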
@@ -687,11 +808,12 @@ def _message_to_generate_content_response(
   """

   parts = []
-  if message.get("content", None):
-    parts.append(types.Part.from_text(text=message.get("content")))
+  message_content, tool_calls = _split_message_content_and_tool_calls(message)
+  if isinstance(message_content, str) and message_content:
+    parts.append(types.Part.from_text(text=message_content))

-  if message.get("tool_calls", None):
-    for tool_call in message.get("tool_calls"):
+  if tool_calls:
+    for tool_call in tool_calls:
       if tool_call.type == "function":
         part = types.Part.from_function_call(
             name=tool_call.function.name,
@@ -24,6 +24,8 @@ from google.adk.models.lite_llm import _get_completion_inputs
 from google.adk.models.lite_llm import _get_content
 from google.adk.models.lite_llm import _message_to_generate_content_response
 from google.adk.models.lite_llm import _model_response_to_chunk
+from google.adk.models.lite_llm import _parse_tool_calls_from_text
+from google.adk.models.lite_llm import _split_message_content_and_tool_calls
 from google.adk.models.lite_llm import _to_litellm_response_format
 from google.adk.models.lite_llm import _to_litellm_role
 from google.adk.models.lite_llm import FunctionChunk
@@ -1452,6 +1454,25 @@ def test_message_to_generate_content_response_tool_call():
   assert response.content.parts[0].function_call.id == "test_tool_call_id"


+def test_message_to_generate_content_response_inline_tool_call_text():
+  message = ChatCompletionAssistantMessage(
+      role="assistant",
+      content=(
+          '{"id":"inline_call","name":"get_current_time",'
+          '"arguments":{"timezone_str":"Asia/Taipei"}} <|im_end|>system'
+      ),
+  )
+
+  response = _message_to_generate_content_response(message)
+  assert len(response.content.parts) == 2
+  text_part = response.content.parts[0]
+  tool_part = response.content.parts[1]
+  assert text_part.text == "<|im_end|>system"
+  assert tool_part.function_call.name == "get_current_time"
+  assert tool_part.function_call.args == {"timezone_str": "Asia/Taipei"}
+  assert tool_part.function_call.id == "inline_call"
+
+
 def test_message_to_generate_content_response_with_model():
   message = ChatCompletionAssistantMessage(
       role="assistant",
@@ -1465,6 +1486,65 @@ def test_message_to_generate_content_response_with_model():
   assert response.model_version == "gemini-2.5-pro"


+def test_parse_tool_calls_from_text_multiple_calls():
+  text = (
+      '{"name":"alpha","arguments":{"value":1}}\n'
+      "Some filler text "
+      '{"id":"custom","name":"beta","arguments":{"timezone":"Asia/Taipei"}} '
+      "ignored suffix"
+  )
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert len(tool_calls) == 2
+  assert tool_calls[0].function.name == "alpha"
+  assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
+  assert tool_calls[1].id == "custom"
+  assert tool_calls[1].function.name == "beta"
+  assert json.loads(tool_calls[1].function.arguments) == {
+      "timezone": "Asia/Taipei"
+  }
+  assert remainder == "Some filler text ignored suffix"
+
+
+def test_parse_tool_calls_from_text_invalid_json_returns_remainder():
+  text = 'Leading {"unused": "payload"} trailing text'
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert tool_calls == []
+  assert remainder == 'Leading {"unused": "payload"} trailing text'
+
+
+def test_split_message_content_and_tool_calls_inline_text():
+  message = {
+      "role": "assistant",
+      "content": (
+          'Intro {"name":"alpha","arguments":{"value":1}} trailing content'
+      ),
+  }
+  content, tool_calls = _split_message_content_and_tool_calls(message)
+  assert content == "Intro trailing content"
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "alpha"
+  assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
+
+
+def test_split_message_content_prefers_existing_structured_calls():
+  tool_call = ChatCompletionMessageToolCall(
+      type="function",
+      id="existing",
+      function=Function(
+          name="existing_call",
+          arguments='{"arg": "value"}',
+      ),
+  )
+  message = {
+      "role": "assistant",
+      "content": "ignored",
+      "tool_calls": [tool_call],
+  }
+  content, tool_calls = _split_message_content_and_tool_calls(message)
+  assert content == "ignored"
+  assert tool_calls == [tool_call]
+
+
 def test_get_content_text():
   parts = [types.Part.from_text(text="Test text")]
   content = _get_content(parts)
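
One subtlety these tests pin down: JSON that parses but does not match the name/arguments shape is treated as ordinary text and kept in the remainder, not dropped. A tiny sketch (the before/after strings here are made up):

    from google.adk.models.lite_llm import _parse_tool_calls_from_text

    # Valid JSON without the tool-call shape stays in the remainder.
    calls, rest = _parse_tool_calls_from_text('before {"unused": "payload"} after')
    assert calls == []
    assert rest == 'before {"unused": "payload"} after'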
@@ -1570,7 +1650,7 @@ def test_to_litellm_role():


 @pytest.mark.parametrize(
-    "response, expected_chunks, expected_finished",
+    "response, expected_chunks, expected_usage_chunk, expected_finished",
     [
         (
             ModelResponse(
@@ -1582,12 +1662,10 @@ def test_to_litellm_role():
                     }
                 ]
             ),
-            [
-                TextChunk(text="this is a test"),
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [TextChunk(text="this is a test")],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             "stop",
         ),
         (
@@ -1605,12 +1683,10 @@ def test_to_litellm_role():
                     "total_tokens": 8,
                 },
             ),
-            [
-                TextChunk(text="this is a test"),
-                UsageMetadataChunk(
-                    prompt_tokens=3, completion_tokens=5, total_tokens=8
-                ),
-            ],
+            [TextChunk(text="this is a test")],
+            UsageMetadataChunk(
+                prompt_tokens=3, completion_tokens=5, total_tokens=8
+            ),
             "stop",
         ),
         (
@@ -1635,52 +1711,121 @@ def test_to_litellm_role():
                     )
                 ]
             ),
-            [
-                FunctionChunk(id="1", name="test_function", args='{"key": "va'),
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [FunctionChunk(id="1", name="test_function", args='{"key": "va')],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             None,
         ),
         (
             ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
-            [
-                None,
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [None],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             "tool_calls",
         ),
         (
             ModelResponse(choices=[{}]),
-            [
-                None,
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [None],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             "stop",
         ),
+        (
+            ModelResponse(
+                choices=[{
+                    "finish_reason": "tool_calls",
+                    "message": {
+                        "role": "assistant",
+                        "content": (
+                            '{"id":"call_1","name":"get_current_time",'
+                            '"arguments":{"timezone_str":"Asia/Taipei"}}'
+                        ),
+                    },
+                }],
+                usage={
+                    "prompt_tokens": 7,
+                    "completion_tokens": 9,
+                    "total_tokens": 16,
+                },
+            ),
+            [
+                FunctionChunk(
+                    id="call_1",
+                    name="get_current_time",
+                    args='{"timezone_str": "Asia/Taipei"}',
+                    index=0,
+                ),
+            ],
+            UsageMetadataChunk(
+                prompt_tokens=7, completion_tokens=9, total_tokens=16
+            ),
+            "tool_calls",
+        ),
+        (
+            ModelResponse(
+                choices=[{
+                    "finish_reason": "tool_calls",
+                    "message": {
+                        "role": "assistant",
+                        "content": (
+                            'Intro {"id":"call_2","name":"alpha",'
+                            '"arguments":{"foo":"bar"}} wrap'
+                        ),
+                    },
+                }],
+                usage={
+                    "prompt_tokens": 11,
+                    "completion_tokens": 13,
+                    "total_tokens": 24,
+                },
+            ),
+            [
+                TextChunk(text="Intro wrap"),
+                FunctionChunk(
+                    id="call_2",
+                    name="alpha",
+                    args='{"foo": "bar"}',
+                    index=0,
+                ),
+            ],
+            UsageMetadataChunk(
+                prompt_tokens=11, completion_tokens=13, total_tokens=24
+            ),
+            "tool_calls",
+        ),
     ],
 )
-def test_model_response_to_chunk(response, expected_chunks, expected_finished):
+def test_model_response_to_chunk(
+    response, expected_chunks, expected_usage_chunk, expected_finished
+):
   result = list(_model_response_to_chunk(response))
-  assert len(result) == 2
-  chunk, finished = result[0]
-  if expected_chunks:
-    assert isinstance(chunk, type(expected_chunks[0]))
-    assert chunk == expected_chunks[0]
-  else:
-    assert chunk is None
-  assert finished == expected_finished
-  usage_chunk, _ = result[1]
-  assert usage_chunk is not None
-  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
-  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
-  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
+  observed_chunks = []
+  usage_chunk = None
+  for chunk, finished in result:
+    if isinstance(chunk, UsageMetadataChunk):
+      usage_chunk = chunk
+      continue
+    observed_chunks.append((chunk, finished))
+
+  assert len(observed_chunks) == len(expected_chunks)
+  for (chunk, finished), expected_chunk in zip(
+      observed_chunks, expected_chunks
+  ):
+    if expected_chunk is None:
+      assert chunk is None
+    else:
+      assert isinstance(chunk, type(expected_chunk))
+      assert chunk == expected_chunk
+    assert finished == expected_finished
+
+  if expected_usage_chunk is None:
+    assert usage_chunk is None
+  else:
+    assert usage_chunk is not None
+    assert usage_chunk == expected_usage_chunk


 @pytest.mark.asyncio