diff --git a/contributing/samples/litellm_inline_tool_call/__init__.py b/contributing/samples/litellm_inline_tool_call/__init__.py
new file mode 100644
index 00000000..976288f8
--- /dev/null
+++ b/contributing/samples/litellm_inline_tool_call/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from . import agent
diff --git a/contributing/samples/litellm_inline_tool_call/agent.py b/contributing/samples/litellm_inline_tool_call/agent.py
new file mode 100644
index 00000000..94847aa8
--- /dev/null
+++ b/contributing/samples/litellm_inline_tool_call/agent.py
@@ -0,0 +1,174 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import datetime
+import json
+import re
+from typing import Any
+from zoneinfo import ZoneInfo
+from zoneinfo import ZoneInfoNotFoundError
+
+from google.adk.agents.llm_agent import Agent
+from google.adk.models.lite_llm import LiteLlm
+from google.adk.models.lite_llm import LiteLLMClient
+
+
+class InlineJsonToolClient(LiteLLMClient):
+  """LiteLLM client that emits inline JSON tool calls for testing."""
+
+  async def acompletion(self, model, messages, tools, **kwargs):
+    del tools, kwargs  # Only needed for API parity.
+
+    tool_message = _find_last_role(messages, role="tool")
+    if tool_message:
+      tool_summary = _coerce_to_text(tool_message.get("content"))
+      return {
+          "id": "mock-inline-tool-final-response",
+          "model": model,
+          "choices": [{
+              "message": {
+                  "role": "assistant",
+                  "content": (
+                      f"The instrumentation tool responded with: {tool_summary}"
+                  ),
+              },
+              "finish_reason": "stop",
+          }],
+          "usage": {
+              "prompt_tokens": 60,
+              "completion_tokens": 12,
+              "total_tokens": 72,
+          },
+      }
+
+    timezone = _extract_timezone(messages) or "Asia/Taipei"
+    inline_call = json.dumps(
+        {
+            "name": "get_current_time",
+            "arguments": {"timezone_str": timezone},
+        },
+        separators=(",", ":"),
+    )
+
+    return {
+        "id": "mock-inline-tool-call",
+        "model": model,
+        "choices": [{
+            "message": {
+                "role": "assistant",
+                "content": (
+                    f"{inline_call}\nLet me double-check the clock for you."
+                ),
+            },
+            "finish_reason": "tool_calls",
+        }],
+        "usage": {
+            "prompt_tokens": 45,
+            "completion_tokens": 15,
+            "total_tokens": 60,
+        },
+    }
+
+
+def _find_last_role(
+    messages: list[dict[str, Any]], role: str
+) -> dict[str, Any]:
+  """Returns the last message with the given role."""
+  for message in reversed(messages):
+    if message.get("role") == role:
+      return message
+  return {}
+
+
+def _coerce_to_text(content: Any) -> str:
+  """Best-effort conversion from OpenAI message content to text."""
+  if isinstance(content, str):
+    return content
+  if isinstance(content, dict):
+    return _coerce_to_text(content.get("text"))
+  if isinstance(content, list):
+    texts = []
+    for part in content:
+      if isinstance(part, dict):
+        texts.append(part.get("text") or "")
+      elif isinstance(part, str):
+        texts.append(part)
+    return " ".join(text for text in texts if text)
+  return ""
+
+
+_TIMEZONE_PATTERN = re.compile(r"([A-Za-z]+/[A-Za-z_]+)")
+
+
+def _extract_timezone(messages: list[dict[str, Any]]) -> str | None:
+  """Extracts an IANA timezone string from the last user message."""
+  user_message = _find_last_role(messages, role="user")
+  text = _coerce_to_text(user_message.get("content"))
+  if not text:
+    return None
+  match = _TIMEZONE_PATTERN.search(text)
+  if match:
+    return match.group(1)
+  lowered = text.lower()
+  if "taipei" in lowered:
+    return "Asia/Taipei"
+  if "new york" in lowered:
+    return "America/New_York"
+  if "london" in lowered:
+    return "Europe/London"
+  if "tokyo" in lowered:
+    return "Asia/Tokyo"
+  return None
+
+
+def get_current_time(timezone_str: str) -> dict[str, str]:
+  """Returns mock current time for the provided timezone."""
+  try:
+    tz = ZoneInfo(timezone_str)
+  except ZoneInfoNotFoundError as exc:
+    return {
+        "status": "error",
+        "report": f"Unable to parse timezone '{timezone_str}': {exc}",
+    }
+  now = datetime.datetime.now(tz)
+  return {
+      "status": "success",
+      "report": (
+          f"The current time in {timezone_str} is"
+          f" {now.strftime('%Y-%m-%d %H:%M:%S %Z')}."
+      ),
+  }
+
+
+_mock_model = LiteLlm(
+    model="mock/inline-json-tool-calls",
+    llm_client=InlineJsonToolClient(),
+)
+
+root_agent = Agent(
+    name="litellm_inline_tool_tester",
+    model=_mock_model,
+    description=(
+        "Demonstrates LiteLLM inline JSON tool-call parsing without an external"
+        " VLLM deployment."
+    ),
+    instruction=(
+        "You are a deterministic clock assistant. Always call the"
+        " get_current_time tool before answering user questions. After the tool"
+        " responds, summarize what it returned."
+    ),
+    tools=[get_current_time],
+)
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index 1eb5cb77..0c373634 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -31,6 +31,7 @@ from typing import Optional
 from typing import Tuple
 from typing import TypedDict
 from typing import Union
+import uuid
 import warnings
 
 from google.genai import types
@@ -64,6 +65,7 @@ logger = logging.getLogger("google_adk." + __name__)
 _NEW_LINE = "\n"
 _EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
 _LITELLM_STRUCTURED_TYPES = {"json_object", "json_schema"}
+_JSON_DECODER = json.JSONDecoder()
 
 # Mapping of LiteLLM finish_reason strings to FinishReason enum values
 # Note: tool_calls/function_call map to STOP because:
@@ -431,6 +433,118 @@ def _get_content(
   return content_objects
 
 
+def _build_tool_call_from_json_dict(
+    candidate: Any, *, index: int
+) -> Optional[ChatCompletionMessageToolCall]:
+  """Creates a tool call object from JSON content embedded in text."""
+
+  if not isinstance(candidate, dict):
+    return None
+
+  name = candidate.get("name")
+  args = candidate.get("arguments")
+  if not isinstance(name, str) or args is None:
+    return None
+
+  if isinstance(args, str):
+    arguments_payload = args
+  else:
+    try:
+      arguments_payload = json.dumps(args, ensure_ascii=False)
+    except (TypeError, ValueError):
+      arguments_payload = _safe_json_serialize(args)
+
+  call_id = candidate.get("id") or f"adk_tool_call_{uuid.uuid4().hex}"
+  call_index = candidate.get("index")
+  if isinstance(call_index, int):
+    index = call_index
+
+  function = Function(
+      name=name,
+      arguments=arguments_payload,
+  )
+  # Some LiteLLM types carry an `index` field only in streaming contexts,
+  # so guard the assignment to stay compatible with older versions.
+  if hasattr(function, "index"):
+    function.index = index  # type: ignore[attr-defined]
+
+  tool_call = ChatCompletionMessageToolCall(
+      type="function",
+      id=str(call_id),
+      function=function,
+  )
+  # Same reasoning as above: not every ChatCompletionMessageToolCall exposes it.
+  if hasattr(tool_call, "index"):
+    tool_call.index = index  # type: ignore[attr-defined]
+
+  return tool_call
+
+
+def _parse_tool_calls_from_text(
+    text_block: str,
+) -> tuple[list[ChatCompletionMessageToolCall], Optional[str]]:
+  """Extracts inline JSON tool calls from LiteLLM text responses."""
+
+  tool_calls = []
+  if not text_block:
+    return tool_calls, None
+
+  remainder_segments = []
+  cursor = 0
+  text_length = len(text_block)
+
+  while cursor < text_length:
+    brace_index = text_block.find("{", cursor)
+    if brace_index == -1:
+      remainder_segments.append(text_block[cursor:])
+      break
+
+    remainder_segments.append(text_block[cursor:brace_index])
+    try:
+      candidate, end = _JSON_DECODER.raw_decode(text_block, brace_index)
+    except json.JSONDecodeError:
+      remainder_segments.append(text_block[brace_index])
+      cursor = brace_index + 1
+      continue
+
+    tool_call = _build_tool_call_from_json_dict(
+        candidate, index=len(tool_calls)
+    )
+    if tool_call:
+      tool_calls.append(tool_call)
+    else:
+      remainder_segments.append(text_block[brace_index:end])
+    cursor = end
+
+  stripped_segments = [segment.strip() for segment in remainder_segments]
+  remainder = " ".join(segment for segment in stripped_segments if segment)
+
+  return tool_calls, remainder or None
+
+
+def _split_message_content_and_tool_calls(
+    message: Message,
+) -> tuple[Optional[OpenAIMessageContent], list[ChatCompletionMessageToolCall]]:
+  """Returns message content and tool calls, parsing inline JSON when needed."""
+
+  existing_tool_calls = message.get("tool_calls") or []
+  normalized_tool_calls = (
+      list(existing_tool_calls) if existing_tool_calls else []
+  )
+  content = message.get("content")
+
+  # LiteLLM responses either provide structured tool_calls or inline JSON, not
+  # both. When tool_calls are present we trust them and skip the fallback parser.
+  if normalized_tool_calls or not isinstance(content, str):
+    return content, normalized_tool_calls
+
+  fallback_tool_calls, remainder = _parse_tool_calls_from_text(content)
+  if fallback_tool_calls:
+    return remainder, fallback_tool_calls
+
+  return content, []
+
+
 def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
   """Converts a types.Content role to a litellm role.
 
@@ -584,15 +698,24 @@ def _model_response_to_chunk(
   if message is None and response["choices"][0].get("delta", None):
     message = response["choices"][0]["delta"]
 
-  if message.get("content", None):
-    yield TextChunk(text=message.get("content")), finish_reason
+  message_content: Optional[OpenAIMessageContent] = None
+  tool_calls: list[ChatCompletionMessageToolCall] = []
+  if message is not None:
+    (
+        message_content,
+        tool_calls,
+    ) = _split_message_content_and_tool_calls(message)
 
-  if message.get("tool_calls", None):
-    for tool_call in message.get("tool_calls"):
+  if message_content:
+    yield TextChunk(text=message_content), finish_reason
+
+  if tool_calls:
+    for idx, tool_call in enumerate(tool_calls):
       # aggregate tool_call
       if tool_call.type == "function":
         func_name = tool_call.function.name
        func_args = tool_call.function.arguments
+        func_index = getattr(tool_call, "index", idx)
 
         # Ignore empty chunks that don't carry any information.
         if not func_name and not func_args:
@@ -602,12 +725,10 @@ def _model_response_to_chunk(
           id=tool_call.id,
           name=func_name,
           args=func_args,
-          index=tool_call.index,
+          index=func_index,
       ), finish_reason
 
-  if finish_reason and not (
-      message.get("content", None) or message.get("tool_calls", None)
-  ):
+  if finish_reason and not (message_content or tool_calls):
     yield None, finish_reason
 
   if not message:
@@ -687,11 +808,12 @@ def _message_to_generate_content_response(
   """
 
   parts = []
-  if message.get("content", None):
-    parts.append(types.Part.from_text(text=message.get("content")))
+  message_content, tool_calls = _split_message_content_and_tool_calls(message)
+  if isinstance(message_content, str) and message_content:
+    parts.append(types.Part.from_text(text=message_content))
 
-  if message.get("tool_calls", None):
-    for tool_call in message.get("tool_calls"):
+  if tool_calls:
+    for tool_call in tool_calls:
       if tool_call.type == "function":
         part = types.Part.from_function_call(
             name=tool_call.function.name,
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 7e4d9887..de5caae9 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -24,6 +24,8 @@ from google.adk.models.lite_llm import _get_completion_inputs
 from google.adk.models.lite_llm import _get_content
 from google.adk.models.lite_llm import _message_to_generate_content_response
 from google.adk.models.lite_llm import _model_response_to_chunk
+from google.adk.models.lite_llm import _parse_tool_calls_from_text
+from google.adk.models.lite_llm import _split_message_content_and_tool_calls
 from google.adk.models.lite_llm import _to_litellm_response_format
 from google.adk.models.lite_llm import _to_litellm_role
 from google.adk.models.lite_llm import FunctionChunk
@@ -1452,6 +1454,25 @@ def test_message_to_generate_content_response_tool_call():
   assert response.content.parts[0].function_call.id == "test_tool_call_id"
 
 
+def test_message_to_generate_content_response_inline_tool_call_text():
+  message = ChatCompletionAssistantMessage(
+      role="assistant",
+      content=(
+          '{"id":"inline_call","name":"get_current_time",'
+          '"arguments":{"timezone_str":"Asia/Taipei"}} <|im_end|>system'
+      ),
+  )
+
+  response = _message_to_generate_content_response(message)
+  assert len(response.content.parts) == 2
+  text_part = response.content.parts[0]
+  tool_part = response.content.parts[1]
+  assert text_part.text == "<|im_end|>system"
+  assert tool_part.function_call.name == "get_current_time"
+  assert tool_part.function_call.args == {"timezone_str": "Asia/Taipei"}
+  assert tool_part.function_call.id == "inline_call"
+
+
 def test_message_to_generate_content_response_with_model():
   message = ChatCompletionAssistantMessage(
       role="assistant",
@@ -1465,6 +1486,65 @@ def test_message_to_generate_content_response_with_model():
   assert response.model_version == "gemini-2.5-pro"
 
 
+def test_parse_tool_calls_from_text_multiple_calls():
+  text = (
+      '{"name":"alpha","arguments":{"value":1}}\n'
+      "Some filler text "
+      '{"id":"custom","name":"beta","arguments":{"timezone":"Asia/Taipei"}} '
+      "ignored suffix"
+  )
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert len(tool_calls) == 2
+  assert tool_calls[0].function.name == "alpha"
+  assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
+  assert tool_calls[1].id == "custom"
+  assert tool_calls[1].function.name == "beta"
+  assert json.loads(tool_calls[1].function.arguments) == {
+      "timezone": "Asia/Taipei"
+  }
+  assert remainder == "Some filler text ignored suffix"
+
+
+def test_parse_tool_calls_from_text_invalid_json_returns_remainder():
+  text = 'Leading {"unused": "payload"} trailing text'
+  tool_calls, remainder = _parse_tool_calls_from_text(text)
+  assert tool_calls == []
+  assert remainder == 'Leading {"unused": "payload"} trailing text'
+
+
+def test_split_message_content_and_tool_calls_inline_text():
+  message = {
+      "role": "assistant",
+      "content": (
+          'Intro {"name":"alpha","arguments":{"value":1}} trailing content'
+      ),
+  }
+  content, tool_calls = _split_message_content_and_tool_calls(message)
+  assert content == "Intro trailing content"
+  assert len(tool_calls) == 1
+  assert tool_calls[0].function.name == "alpha"
+  assert json.loads(tool_calls[0].function.arguments) == {"value": 1}
+
+
+def test_split_message_content_prefers_existing_structured_calls():
+  tool_call = ChatCompletionMessageToolCall(
+      type="function",
+      id="existing",
+      function=Function(
+          name="existing_call",
+          arguments='{"arg": "value"}',
+      ),
+  )
+  message = {
+      "role": "assistant",
+      "content": "ignored",
+      "tool_calls": [tool_call],
+  }
+  content, tool_calls = _split_message_content_and_tool_calls(message)
+  assert content == "ignored"
+  assert tool_calls == [tool_call]
+
+
 def test_get_content_text():
   parts = [types.Part.from_text(text="Test text")]
   content = _get_content(parts)
@@ -1570,47 +1650,43 @@ def test_to_litellm_role():
 
 
 @pytest.mark.parametrize(
-    "response, expected_chunks, expected_finished",
+    "response, expected_chunks, expected_usage_chunk, expected_finished",
     [
         (
             ModelResponse(
                 choices=[
                     {
                         "message": {
                            "content": "this is a test",
                        }
                    }
                ]
            ),
-            [
-                TextChunk(text="this is a test"),
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [TextChunk(text="this is a test")],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             "stop",
         ),
         (
            ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "this is a test",
                        }
                    }
                ],
                usage={
                    "prompt_tokens": 3,
                    "completion_tokens": 5,
                    "total_tokens": 8,
                },
            ),
-            [
-                TextChunk(text="this is a test"),
-                UsageMetadataChunk(
-                    prompt_tokens=3, completion_tokens=5, total_tokens=8
-                ),
-            ],
+            [TextChunk(text="this is a test")],
+            UsageMetadataChunk(
+                prompt_tokens=3, completion_tokens=5, total_tokens=8
+            ),
             "stop",
         ),
         (
@@ -1635,52 +1711,121 @@ def test_to_litellm_role():
                 )
             ]
         ),
-            [
-                FunctionChunk(id="1", name="test_function", args='{"key": "va'),
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [FunctionChunk(id="1", name="test_function", args='{"key": "va')],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
             None,
         ),
         (
             ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
-            [
-                None,
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [None],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
            "tool_calls",
        ),
        (
            ModelResponse(choices=[{}]),
-            [
-                None,
-                UsageMetadataChunk(
-                    prompt_tokens=0, completion_tokens=0, total_tokens=0
-                ),
-            ],
+            [None],
+            UsageMetadataChunk(
+                prompt_tokens=0, completion_tokens=0, total_tokens=0
+            ),
            "stop",
        ),
+        (
+            ModelResponse(
+                choices=[{
+                    "finish_reason": "tool_calls",
+                    "message": {
+                        "role": "assistant",
+                        "content": (
+                            '{"id":"call_1","name":"get_current_time",'
+                            '"arguments":{"timezone_str":"Asia/Taipei"}}'
+                        ),
+                    },
+                }],
+                usage={
+                    "prompt_tokens": 7,
+                    "completion_tokens": 9,
+                    "total_tokens": 16,
+                },
+            ),
+            [
+                FunctionChunk(
+                    id="call_1",
+                    name="get_current_time",
+                    args='{"timezone_str": "Asia/Taipei"}',
+                    index=0,
+                ),
+            ],
+            UsageMetadataChunk(
+                prompt_tokens=7, completion_tokens=9, total_tokens=16
+            ),
+            "tool_calls",
+        ),
+        (
+            ModelResponse(
+                choices=[{
+                    "finish_reason": "tool_calls",
+                    "message": {
+                        "role": "assistant",
+                        "content": (
+                            'Intro {"id":"call_2","name":"alpha",'
+                            '"arguments":{"foo":"bar"}} wrap'
+                        ),
+                    },
+                }],
+                usage={
+                    "prompt_tokens": 11,
+                    "completion_tokens": 13,
+                    "total_tokens": 24,
+                },
+            ),
+            [
+                TextChunk(text="Intro wrap"),
+                FunctionChunk(
+                    id="call_2",
+                    name="alpha",
+                    args='{"foo": "bar"}',
+                    index=0,
+                ),
+            ],
+            UsageMetadataChunk(
+                prompt_tokens=11, completion_tokens=13, total_tokens=24
+            ),
+            "tool_calls",
+        ),
     ],
 )
-def test_model_response_to_chunk(response, expected_chunks, expected_finished):
+def test_model_response_to_chunk(
+    response, expected_chunks, expected_usage_chunk, expected_finished
+):
   result = list(_model_response_to_chunk(response))
-  assert len(result) == 2
-  chunk, finished = result[0]
-  if expected_chunks:
-    assert isinstance(chunk, type(expected_chunks[0]))
-    assert chunk == expected_chunks[0]
-  else:
-    assert chunk is None
-  assert finished == expected_finished
+  observed_chunks = []
+  usage_chunk = None
+  for chunk, finished in result:
+    if isinstance(chunk, UsageMetadataChunk):
+      usage_chunk = chunk
+      continue
+    observed_chunks.append((chunk, finished))
 
-  usage_chunk, _ = result[1]
-  assert usage_chunk is not None
-  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
-  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
-  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens
+  assert len(observed_chunks) == len(expected_chunks)
+  for (chunk, finished), expected_chunk in zip(
+      observed_chunks, expected_chunks
+  ):
+    if expected_chunk is None:
+      assert chunk is None
+    else:
+      assert isinstance(chunk, type(expected_chunk))
+      assert chunk == expected_chunk
+    assert finished == expected_finished
+
+  if expected_usage_chunk is None:
+    assert usage_chunk is None
+  else:
+    assert usage_chunk is not None
+    assert usage_chunk == expected_usage_chunk
 
 
 @pytest.mark.asyncio