fix: Pass drop_params to LiteLLM completion API

This lets users specify `drop_params` when initializing `LiteLlm`; the value is then forwarded to LiteLLM's `acompletion` or `completion` calls.

Close #1718

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 828058105
This commit is contained in:
George Weale
2025-11-04 11:19:14 -08:00
committed by Copybara-Service
parent d4c63fc562
commit ce91a8ef73
2 changed files with 43 additions and 1 deletions
+4 -1
View File
@@ -866,10 +866,11 @@ class LiteLlm(BaseLlm):
model: The name of the LiteLlm model.
**kwargs: Additional arguments to pass to the litellm completion api.
"""
drop_params = kwargs.pop("drop_params", None)
super().__init__(model=model, **kwargs)
# Warn if using Gemini via LiteLLM
_warn_gemini_via_litellm(model)
self._additional_args = kwargs
self._additional_args = dict(kwargs)
# preventing generation call with llm_client
# and overriding messages, tools and stream which are managed internally
self._additional_args.pop("llm_client", None)
@@ -877,6 +878,8 @@ class LiteLlm(BaseLlm):
self._additional_args.pop("tools", None)
# public api called from runner determines to stream or not
self._additional_args.pop("stream", None)
if drop_params is not None:
self._additional_args["drop_params"] = drop_params
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
+39
View File
@@ -1519,6 +1519,23 @@ async def test_acompletion_additional_args(mock_acompletion, mock_client):
assert kwargs["api_base"] == "some://url"
# Checks that `drop_params=True` given to the LiteLlm constructor appears in
# the keyword arguments of the (mocked) acompletion call.
@pytest.mark.asyncio
async def test_acompletion_with_drop_params(mock_acompletion, mock_client):
lite_llm_instance = LiteLlm(
model="test_model", llm_client=mock_client, drop_params=True
)
# Drain the async generator; the yielded responses are not inspected here,
# only the call recorded on the mock below.
async for _ in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
pass
mock_acompletion.assert_called_once()
# call_args unpacks to (positional args, keyword args); only kwargs matter.
_, kwargs = mock_acompletion.call_args
assert kwargs["drop_params"] is True
@pytest.mark.asyncio
async def test_completion_additional_args(mock_completion, mock_client):
lite_llm_instance = LiteLlm(
@@ -1561,6 +1578,28 @@ async def test_completion_additional_args(mock_completion, mock_client):
assert kwargs["api_base"] == "some://url"
# Streaming counterpart of the drop_params check: with `stream=True`, asserts
# that `drop_params=True` from the constructor appears in the keyword
# arguments of the (mocked) completion call.
@pytest.mark.asyncio
async def test_completion_with_drop_params(mock_completion, mock_client):
lite_llm_instance = LiteLlm(
model="test_model", llm_client=mock_client, drop_params=True
)
# The mock returns an iterator over the module-level streaming fixture.
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
responses = [
response
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
)
]
# NOTE(review): 4 presumably follows from the contents of
# STREAMING_MODEL_RESPONSE, which is defined outside this view — confirm.
assert len(responses) == 4
mock_completion.assert_called_once()
# call_args unpacks to (positional args, keyword args); only kwargs matter.
_, kwargs = mock_completion.call_args
assert kwargs["drop_params"] is True
@pytest.mark.asyncio
async def test_generate_content_async_stream(
mock_completion, lite_llm_instance