chore: Add warning for using Gemini models via LiteLLM

Recommend using Gemini directly instead of routing it through LiteLLM

PiperOrigin-RevId: 800971705
This commit is contained in:
George Weale
2025-08-29 11:32:07 -07:00
committed by Copybara-Service
parent fcd748e17f
commit 9291daaa8e
2 changed files with 115 additions and 0 deletions
+66
View File
@@ -17,6 +17,8 @@ from __future__ import annotations
import base64
import json
import logging
import os
import re
from typing import Any
from typing import AsyncGenerator
from typing import cast
@@ -29,6 +31,7 @@ from typing import Optional
from typing import Tuple
from typing import TypedDict
from typing import Union
import warnings
from google.genai import types
import litellm
@@ -672,6 +675,67 @@ Functions:
"""
def _is_litellm_gemini_model(model_string: str) -> bool:
"""Check if the model is a Gemini model accessed via LiteLLM.
Args:
model_string: A LiteLLM model string (e.g., "gemini/gemini-2.5-pro" or
"vertex_ai/gemini-1.5-flash")
Returns:
True if it's a Gemini model accessed via LiteLLM, False otherwise
"""
# Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
pattern = r"^(gemini|vertex_ai)/gemini-"
return bool(re.match(pattern, model_string))
def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
"""Extract the pure Gemini model name from a LiteLLM model string.
Args:
litellm_model: LiteLLM model string like "gemini/gemini-2.5-pro"
Returns:
Pure Gemini model name like "gemini-2.5-pro"
"""
# Remove LiteLLM provider prefix
if "/" in litellm_model:
return litellm_model.split("/", 1)[1]
return litellm_model
def _warn_gemini_via_litellm(model_string: str) -> None:
  """Warn if Gemini is being used via LiteLLM.

  Emits a ``UserWarning`` suggesting users use Gemini directly rather than
  through LiteLLM for better performance and features. The warning can be
  silenced by setting the ``ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS``
  environment variable to a truthy value.

  Args:
    model_string: The LiteLLM model string to check
  """
  # Non-Gemini models never warn.
  if not _is_litellm_gemini_model(model_string):
    return

  # Honor the opt-out environment variable (case-insensitive truthy values).
  suppress_flag = os.environ.get("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "")
  if suppress_flag.strip().lower() in ("1", "true", "yes", "on"):
    return

  direct_model = _extract_gemini_model_from_litellm(model_string)
  message = (
      f"[GEMINI_VIA_LITELLM] {model_string}: You are using Gemini via LiteLLM."
      " For better performance, reliability, and access to latest features,"
      " consider using Gemini directly through ADK's native Gemini"
      f" integration. Replace LiteLlm(model='{model_string}') with"
      f" Gemini(model='{direct_model}')."
      " Set ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS=true to suppress this"
      " warning."
  )
  # stacklevel=3 points the warning at the caller of LiteLlm(...), not at
  # this helper or LiteLlm.__init__.
  warnings.warn(message, category=UserWarning, stacklevel=3)
class LiteLlm(BaseLlm):
"""Wrapper around litellm.
@@ -708,6 +772,8 @@ class LiteLlm(BaseLlm):
**kwargs: Additional arguments to pass to the litellm completion api.
"""
super().__init__(model=model, **kwargs)
# Warn if using Gemini via LiteLLM
_warn_gemini_via_litellm(model)
self._additional_args = kwargs
# preventing generation call with llm_client
# and overriding messages, tools and stream which are managed internally
+49
View File
@@ -16,6 +16,7 @@
import json
from unittest.mock import AsyncMock
from unittest.mock import Mock
import warnings
from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _function_declaration_to_tool_param
@@ -1574,3 +1575,51 @@ def test_get_completion_inputs_generation_params():
# Should not include max_output_tokens
assert "max_output_tokens" not in generation_params
assert "stop_sequences" not in generation_params
def test_gemini_via_litellm_warning(monkeypatch):
  """Test that Gemini via LiteLLM shows warning."""
  # Make sure the suppression environment variable is absent.
  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)

  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Google AI Studio Gemini routed through LiteLLM must warn exactly once.
    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")

  assert len(caught) == 1
  assert issubclass(caught[0].category, UserWarning)
  message_text = str(caught[0].message)
  assert "[GEMINI_VIA_LITELLM]" in message_text
  assert "better performance" in message_text
  assert "gemini-2.5-pro-exp-03-25" in message_text
  assert "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS" in message_text
def test_gemini_via_litellm_warning_vertex_ai(monkeypatch):
  """Test that Vertex AI Gemini via LiteLLM shows warning."""
  # Make sure the suppression environment variable is absent.
  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)

  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Vertex AI Gemini routed through LiteLLM must warn exactly once.
    LiteLlm(model="vertex_ai/gemini-1.5-flash")

  assert len(caught) == 1
  assert issubclass(caught[0].category, UserWarning)
  message_text = str(caught[0].message)
  assert "[GEMINI_VIA_LITELLM]" in message_text
  assert "vertex_ai/gemini-1.5-flash" in message_text
def test_gemini_via_litellm_warning_suppressed(monkeypatch):
  """Test that Gemini via LiteLLM warning can be suppressed."""
  # With the opt-out variable set, construction must be warning-free.
  monkeypatch.setenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "true")

  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")

  assert not caught
def test_non_gemini_litellm_no_warning():
  """Test that non-Gemini models via LiteLLM don't show warning."""
  with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # A non-Gemini provider/model must construct silently.
    LiteLlm(model="openai/gpt-4o")

  assert not caught