feat: Add support for parsing inline JSON tool calls in LiteLLM responses

Close #1968
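
Some OpenAI-compatible backends (for example self-hosted VLLM endpoints)
emit tool invocations as raw JSON in the assistant text instead of the
structured tool_calls field, e.g.

    {"name": "get_current_time", "arguments": {"timezone_str": "Asia/Taipei"}}

The fallback parser added here lifts each such payload out of the text,
converts it into a function tool call, and keeps the surrounding prose as
the message content. Structured tool_calls, when present, are still trusted
as-is and the fallback parser is skipped.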

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 831911719
Author: George Weale
Date: 2025-11-13 10:16:58 -08:00
Committed by: Copybara-Service
Parent: 2efc184a46
Commit: 22eb7e5b06

4 changed files with 513 additions and 55 deletions
@@ -0,0 +1,17 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from . import agent
@@ -0,0 +1,174 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import json
import re
from typing import Any
from zoneinfo import ZoneInfo
from zoneinfo import ZoneInfoNotFoundError
from google.adk.agents.llm_agent import Agent
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
class InlineJsonToolClient(LiteLLMClient):
"""LiteLLM client that emits inline JSON tool calls for testing."""
async def acompletion(self, model, messages, tools, **kwargs):
del tools, kwargs  # Accepted for API parity; unused by the mock.
tool_message = _find_last_role(messages, role="tool")
if tool_message:
tool_summary = _coerce_to_text(tool_message.get("content"))
return {
"id": "mock-inline-tool-final-response",
"model": model,
"choices": [{
"message": {
"role": "assistant",
"content": (
f"The instrumentation tool responded with: {tool_summary}"
),
},
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": 60,
"completion_tokens": 12,
"total_tokens": 72,
},
}
timezone = _extract_timezone(messages) or "Asia/Taipei"
inline_call = json.dumps(
{
"name": "get_current_time",
"arguments": {"timezone_str": timezone},
},
separators=(",", ":"),
)
return {
"id": "mock-inline-tool-call",
"model": model,
"choices": [{
"message": {
"role": "assistant",
"content": (
f"{inline_call}\nLet me double-check the clock for you."
),
},
"finish_reason": "tool_calls",
}],
"usage": {
"prompt_tokens": 45,
"completion_tokens": 15,
"total_tokens": 60,
},
}
def _find_last_role(
messages: list[dict[str, Any]], role: str
) -> dict[str, Any]:
"""Returns the last message with the given role."""
for message in reversed(messages):
if message.get("role") == role:
return message
return {}
def _coerce_to_text(content: Any) -> str:
"""Best-effort conversion from OpenAI message content to text."""
if isinstance(content, str):
return content
if isinstance(content, dict):
return _coerce_to_text(content.get("text"))
if isinstance(content, list):
texts = []
for part in content:
if isinstance(part, dict):
texts.append(part.get("text") or "")
elif isinstance(part, str):
texts.append(part)
return " ".join(text for text in texts if text)
return ""
_TIMEZONE_PATTERN = re.compile(r"([A-Za-z]+/[A-Za-z_]+)")
def _extract_timezone(messages: list[dict[str, Any]]) -> str | None:
"""Extracts an IANA timezone string from the last user message."""
user_message = _find_last_role(messages, role="user")
text = _coerce_to_text(user_message.get("content"))
if not text:
return None
match = _TIMEZONE_PATTERN.search(text)
if match:
return match.group(1)
lowered = text.lower()
if "taipei" in lowered:
return "Asia/Taipei"
if "new york" in lowered:
return "America/New_York"
if "london" in lowered:
return "Europe/London"
if "tokyo" in lowered:
return "Asia/Tokyo"
return None
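# Example (illustrative): an explicit IANA name takes precedence, so
# "time in Europe/Paris please" -> "Europe/Paris", while a bare city like
# "what time is it in london?" falls back to "Europe/London".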
def get_current_time(timezone_str: str) -> dict[str, str]:
"""Returns mock current time for the provided timezone."""
try:
tz = ZoneInfo(timezone_str)
except ZoneInfoNotFoundError as exc:
return {
"status": "error",
"report": f"Unable to parse timezone '{timezone_str}': {exc}",
}
now = datetime.datetime.now(tz)
return {
"status": "success",
"report": (
f"The current time in {timezone_str} is"
f" {now.strftime('%Y-%m-%d %H:%M:%S %Z')}."
),
}
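# Example result shape (the timestamp depends on when this runs):
#   get_current_time("Asia/Taipei")
#   -> {"status": "success",
#       "report": "The current time in Asia/Taipei is <YYYY-MM-DD HH:MM:SS CST>."}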
_mock_model = LiteLlm(
model="mock/inline-json-tool-calls",
llm_client=InlineJsonToolClient(),
)
root_agent = Agent(
name="litellm_inline_tool_tester",
model=_mock_model,
description=(
"Demonstrates LiteLLM inline JSON tool-call parsing without an external"
" VLLM deployment."
),
instruction=(
"You are a deterministic clock assistant. Always call the"
" get_current_time tool before answering user questions. After the tool"
" responds, summarize what it returned."
),
tools=[get_current_time],
)
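# Illustrative, not part of the original sample: running this module directly
# exercises the mock client once, outside the ADK runtime. The user message
# below is an assumption chosen to trigger the Tokyo fallback.
if __name__ == "__main__":
  import asyncio

  async def _demo() -> None:
    response = await InlineJsonToolClient().acompletion(
        model="mock/inline-json-tool-calls",
        messages=[{"role": "user", "content": "What time is it in Tokyo?"}],
        tools=[],
    )
    # The tool call arrives as inline JSON on the first line of the text.
    print(response["choices"][0]["message"]["content"])

  asyncio.run(_demo())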
@@ -31,6 +31,7 @@ from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union
import uuid
import warnings
from google.genai import types
@@ -64,6 +65,7 @@ logger = logging.getLogger("google_adk." + __name__)
_NEW_LINE = "\n"
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
_LITELLM_STRUCTURED_TYPES = {"json_object", "json_schema"}
_JSON_DECODER = json.JSONDecoder()
# Mapping of LiteLLM finish_reason strings to FinishReason enum values
# Note: tool_calls/function_call map to STOP because:
@@ -431,6 +433,118 @@ def _get_content(
return content_objects
def _build_tool_call_from_json_dict(
candidate: Any, *, index: int
) -> Optional[ChatCompletionMessageToolCall]:
"""Creates a tool call object from JSON content embedded in text."""
if not isinstance(candidate, dict):
return None
name = candidate.get("name")
args = candidate.get("arguments")
if not isinstance(name, str) or args is None:
return None
if isinstance(args, str):
arguments_payload = args
else:
try:
arguments_payload = json.dumps(args, ensure_ascii=False)
except (TypeError, ValueError):
arguments_payload = _safe_json_serialize(args)
call_id = candidate.get("id") or f"adk_tool_call_{uuid.uuid4().hex}"
call_index = candidate.get("index")
if isinstance(call_index, int):
index = call_index
function = Function(
name=name,
arguments=arguments_payload,
)
# Some LiteLLM types carry an `index` field only in streaming contexts,
# so guard the assignment to stay compatible with older versions.
if hasattr(function, "index"):
function.index = index # type: ignore[attr-defined]
tool_call = ChatCompletionMessageToolCall(
type="function",
id=str(call_id),
function=function,
)
# Same reasoning as above: not every ChatCompletionMessageToolCall exposes it.
if hasattr(tool_call, "index"):
tool_call.index = index # type: ignore[attr-defined]
return tool_call
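# Illustrative (not part of the diff's public surface): a well-formed dict
#   {"name": "alpha", "arguments": {"value": 1}}
# becomes a function tool call whose arguments serialize to '{"value": 1}';
# a candidate without a string `name` plus an `arguments` entry yields None
# and is left in the surrounding text by the caller.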
def _parse_tool_calls_from_text(
text_block: str,
) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
"""Extracts inline JSON tool calls from LiteLLM text responses."""
tool_calls = []
if not text_block:
return tool_calls, None
remainder_segments = []
cursor = 0
text_length = len(text_block)
while cursor < text_length:
brace_index = text_block.find("{", cursor)
if brace_index == -1:
remainder_segments.append(text_block[cursor:])
break
remainder_segments.append(text_block[cursor:brace_index])
try:
candidate, end = _JSON_DECODER.raw_decode(text_block, brace_index)
except json.JSONDecodeError:
remainder_segments.append(text_block[brace_index])
cursor = brace_index + 1
continue
tool_call = _build_tool_call_from_json_dict(
candidate, index=len(tool_calls)
)
if tool_call:
tool_calls.append(tool_call)
else:
remainder_segments.append(text_block[brace_index:end])
cursor = end
remainder = "".join(segment for segment in remainder_segments if segment)
remainder = remainder.strip()
return tool_calls, remainder or None
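# Example (mirrors the unit tests): given
#   'Intro {"name":"alpha","arguments":{"value":1}} trailing content'
# the parser returns one tool call named "alpha" plus the remainder
# "Intro trailing content"; JSON objects that are not tool calls stay in the
# remainder verbatim.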
def _split_message_content_and_tool_calls(
message: Message,
) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]:
"""Returns message content and tool calls, parsing inline JSON when needed."""
existing_tool_calls = message.get("tool_calls") or []
normalized_tool_calls = (
list(existing_tool_calls) if existing_tool_calls else []
)
content = message.get("content")
# LiteLLM responses either provide structured tool_calls or inline JSON, not
# both. When tool_calls are present we trust them and skip the fallback parser.
if normalized_tool_calls or not isinstance(content, str):
return content, normalized_tool_calls
fallback_tool_calls, remainder = _parse_tool_calls_from_text(content)
if fallback_tool_calls:
return remainder, fallback_tool_calls
return content, []
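# Example: {"role": "assistant", "content": '{"name":"alpha","arguments":{}}'}
# splits into (None, [alpha tool call]), while a message that already carries
# structured tool_calls is returned unchanged.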
def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
"""Converts a types.Content role to a litellm role.
@@ -584,15 +698,24 @@ def _model_response_to_chunk(
if message is None and response["choices"][0].get("delta", None):
message = response["choices"][0]["delta"]
if message.get("content", None):
yield TextChunk(text=message.get("content")), finish_reason
message_content: Optional[OpenAIMessageContent] = None
tool_calls: list[ChatCompletionMessageToolCall] = []
if message is not None:
(
message_content,
tool_calls,
) = _split_message_content_and_tool_calls(message)
if message.get("tool_calls", None):
for tool_call in message.get("tool_calls"):
if message_content:
yield TextChunk(text=message_content), finish_reason
if tool_calls:
for idx, tool_call in enumerate(tool_calls):
# aggregate tool_call
if tool_call.type == "function":
func_name = tool_call.function.name
func_args = tool_call.function.arguments
func_index = getattr(tool_call, "index", idx)
# Ignore empty chunks that don't carry any information.
if not func_name and not func_args:
@@ -602,12 +725,10 @@ def _model_response_to_chunk(
id=tool_call.id,
name=func_name,
args=func_args,
index=func_index,
), finish_reason
if finish_reason and not (message_content or tool_calls):
yield None, finish_reason
if not message:
@@ -687,11 +808,12 @@ def _message_to_generate_content_response(
"""
parts = []
message_content, tool_calls = _split_message_content_and_tool_calls(message)
if isinstance(message_content, str) and message_content:
  parts.append(types.Part.from_text(text=message_content))
if tool_calls:
  for tool_call in tool_calls:
if tool_call.type == "function":
part = types.Part.from_function_call(
name=tool_call.function.name,
@@ -24,6 +24,8 @@ from google.adk.models.lite_llm import _get_completion_inputs
from google.adk.models.lite_llm import _get_content
from google.adk.models.lite_llm import _message_to_generate_content_response
from google.adk.models.lite_llm import _model_response_to_chunk
from google.adk.models.lite_llm import _parse_tool_calls_from_text
from google.adk.models.lite_llm import _split_message_content_and_tool_calls
from google.adk.models.lite_llm import _to_litellm_response_format
from google.adk.models.lite_llm import _to_litellm_role
from google.adk.models.lite_llm import FunctionChunk
@@ -1452,6 +1454,25 @@ def test_message_to_generate_content_response_tool_call():
assert response.content.parts[0].function_call.id == "test_tool_call_id"
def test_message_to_generate_content_response_inline_tool_call_text():
message = ChatCompletionAssistantMessage(
role="assistant",
content=(
'{"id":"inline_call","name":"get_current_time",'
'"arguments":{"timezone_str":"Asia/Taipei"}} <|im_end|>system'
),
)
response = _message_to_generate_content_response(message)
assert len(response.content.parts) == 2
text_part = response.content.parts[0]
tool_part = response.content.parts[1]
assert text_part.text == "<|im_end|>system"
assert tool_part.function_call.name == "get_current_time"
assert tool_part.function_call.args == {"timezone_str": "Asia/Taipei"}
assert tool_part.function_call.id == "inline_call"
def test_message_to_generate_content_response_with_model():
message = ChatCompletionAssistantMessage(
role="assistant",
@@ -1465,6 +1486,65 @@ def test_message_to_generate_content_response_with_model():
assert response.model_version == "gemini-2.5-pro"
def test_parse_tool_calls_from_text_multiple_calls():
text = (
'{"name":"alpha","arguments":{"value":1}}\n'
"Some filler text "
'{"id":"custom","name":"beta","arguments":{"timezone":"Asia/Taipei"}} '
"ignored suffix"
)
tool_calls, remainder = _parse_tool_calls_from_text(text)
assert len(tool_calls) == 2
assert tool_calls[0].function.name == "alpha"
assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
assert tool_calls[1].id == "custom"
assert tool_calls[1].function.name == "beta"
assert json.loads(tool_calls[1].function.arguments) == {
"timezone": "Asia/Taipei"
}
assert remainder == "Some filler text ignored suffix"
def test_parse_tool_calls_from_text_invalid_json_returns_remainder():
text = 'Leading {"unused": "payload"} trailing text'
tool_calls, remainder = _parse_tool_calls_from_text(text)
assert tool_calls == []
assert remainder == 'Leading {"unused": "payload"} trailing text'
def test_split_message_content_and_tool_calls_inline_text():
message = {
"role": "assistant",
"content": (
'Intro {"name":"alpha","arguments":{"value":1}} trailing content'
),
}
content, tool_calls = _split_message_content_and_tool_calls(message)
assert content == "Intro trailing content"
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "alpha"
assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
def test_split_message_content_prefers_existing_structured_calls():
tool_call = ChatCompletionMessageToolCall(
type="function",
id="existing",
function=Function(
name="existing_call",
arguments='{"arg": "value"}',
),
)
message = {
"role": "assistant",
"content": "ignored",
"tool_calls": [tool_call],
}
content, tool_calls = _split_message_content_and_tool_calls(message)
assert content == "ignored"
assert tool_calls == [tool_call]
def test_get_content_text():
parts = [types.Part.from_text(text="Test text")]
content = _get_content(parts)
@@ -1570,7 +1650,7 @@ def test_to_litellm_role():
@pytest.mark.parametrize(
"response, expected_chunks, expected_finished",
"response, expected_chunks, expected_usage_chunk, expected_finished",
[
(
ModelResponse(
@@ -1582,12 +1662,10 @@ def test_to_litellm_role():
}
]
),
[TextChunk(text="this is a test")],
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
"stop",
),
(
@@ -1605,12 +1683,10 @@ def test_to_litellm_role():
"total_tokens": 8,
},
),
[TextChunk(text="this is a test")],
UsageMetadataChunk(
prompt_tokens=3, completion_tokens=5, total_tokens=8
),
"stop",
),
(
@@ -1635,52 +1711,121 @@ def test_to_litellm_role():
)
]
),
[FunctionChunk(id="1", name="test_function", args='{"key": "va')],
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
None,
),
(
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
[None],
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
"tool_calls",
),
(
ModelResponse(choices=[{}]),
[None],
UsageMetadataChunk(
prompt_tokens=0, completion_tokens=0, total_tokens=0
),
"stop",
),
(
ModelResponse(
choices=[{
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": (
'{"id":"call_1","name":"get_current_time",'
'"arguments":{"timezone_str":"Asia/Taipei"}}'
),
},
}],
usage={
"prompt_tokens": 7,
"completion_tokens": 9,
"total_tokens": 16,
},
),
[
    FunctionChunk(
        id="call_1",
        name="get_current_time",
        args='{"timezone_str": "Asia/Taipei"}',
        index=0,
    ),
],
UsageMetadataChunk(
    prompt_tokens=7, completion_tokens=9, total_tokens=16
),
"tool_calls",
),
(
ModelResponse(
choices=[{
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": (
'Intro {"id":"call_2","name":"alpha",'
'"arguments":{"foo":"bar"}} wrap'
),
},
}],
usage={
"prompt_tokens": 11,
"completion_tokens": 13,
"total_tokens": 24,
},
),
[
TextChunk(text="Intro wrap"),
FunctionChunk(
id="call_2",
name="alpha",
args='{"foo": "bar"}',
index=0,
),
],
UsageMetadataChunk(
prompt_tokens=11, completion_tokens=13, total_tokens=24
),
"tool_calls",
),
],
)
def test_model_response_to_chunk(
    response, expected_chunks, expected_usage_chunk, expected_finished
):
result = list(_model_response_to_chunk(response))
observed_chunks = []
usage_chunk = None
for chunk, finished in result:
if isinstance(chunk, UsageMetadataChunk):
usage_chunk = chunk
continue
observed_chunks.append((chunk, finished))
assert len(observed_chunks) == len(expected_chunks)
for (chunk, finished), expected_chunk in zip(
observed_chunks, expected_chunks
):
if expected_chunk is None:
assert chunk is None
else:
assert isinstance(chunk, type(expected_chunk))
assert chunk == expected_chunk
assert finished == expected_finished
if expected_usage_chunk is None:
assert usage_chunk is None
else:
assert usage_chunk is not None
assert usage_chunk == expected_usage_chunk
@pytest.mark.asyncio