You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Add warning for using Gemini models via LiteLLM
Recommend using Gemini directly rather than through LiteLLM.
PiperOrigin-RevId: 800971705
This commit is contained in:
committed by
Copybara-Service
parent
fcd748e17f
commit
9291daaa8e
@@ -17,6 +17,8 @@ from __future__ import annotations
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
@@ -29,6 +31,7 @@ from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypedDict
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
from google.genai import types
|
||||
import litellm
|
||||
@@ -672,6 +675,67 @@ Functions:
|
||||
"""
|
||||
|
||||
|
||||
def _is_litellm_gemini_model(model_string: str) -> bool:
|
||||
"""Check if the model is a Gemini model accessed via LiteLLM.
|
||||
|
||||
Args:
|
||||
model_string: A LiteLLM model string (e.g., "gemini/gemini-2.5-pro" or
|
||||
"vertex_ai/gemini-1.5-flash")
|
||||
|
||||
Returns:
|
||||
True if it's a Gemini model accessed via LiteLLM, False otherwise
|
||||
"""
|
||||
# Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
|
||||
pattern = r"^(gemini|vertex_ai)/gemini-"
|
||||
return bool(re.match(pattern, model_string))
|
||||
|
||||
|
||||
def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
|
||||
"""Extract the pure Gemini model name from a LiteLLM model string.
|
||||
|
||||
Args:
|
||||
litellm_model: LiteLLM model string like "gemini/gemini-2.5-pro"
|
||||
|
||||
Returns:
|
||||
Pure Gemini model name like "gemini-2.5-pro"
|
||||
"""
|
||||
# Remove LiteLLM provider prefix
|
||||
if "/" in litellm_model:
|
||||
return litellm_model.split("/", 1)[1]
|
||||
return litellm_model
|
||||
|
||||
|
||||
def _warn_gemini_via_litellm(model_string: str) -> None:
  """Emit a UserWarning when a Gemini model is requested through LiteLLM.

  The warning suggests switching to ADK's native Gemini integration and
  includes the equivalent native model name. Nothing is emitted when the
  model string is not a LiteLLM Gemini model, or when the environment
  variable ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS is set to one of "1", "true",
  "yes" or "on" (case-insensitive, surrounding whitespace ignored).

  Args:
    model_string: The LiteLLM model string to check
  """
  if not _is_litellm_gemini_model(model_string):
    return

  # Opt-out switch for users who knowingly route Gemini through LiteLLM.
  suppress_flag = os.environ.get("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "")
  if suppress_flag.strip().lower() in {"1", "true", "yes", "on"}:
    return

  native_model = _extract_gemini_model_from_litellm(model_string)
  message = (
      f"[GEMINI_VIA_LITELLM] {model_string}: You are using Gemini via LiteLLM."
      " For better performance, reliability, and access to latest features,"
      " consider using Gemini directly through ADK's native Gemini"
      f" integration. Replace LiteLlm(model='{model_string}') with"
      f" Gemini(model='{native_model}')."
      " Set ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS=true to suppress this"
      " warning."
  )
  # stacklevel=3 attributes the warning two frames above this helper (the
  # caller of whatever called this function) instead of to this module.
  warnings.warn(message, category=UserWarning, stacklevel=3)
|
||||
|
||||
|
||||
class LiteLlm(BaseLlm):
|
||||
"""Wrapper around litellm.
|
||||
|
||||
@@ -708,6 +772,8 @@ class LiteLlm(BaseLlm):
|
||||
**kwargs: Additional arguments to pass to the litellm completion api.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
# Warn if using Gemini via LiteLLM
|
||||
_warn_gemini_via_litellm(model)
|
||||
self._additional_args = kwargs
|
||||
# preventing generation call with llm_client
|
||||
# and overriding messages, tools and stream which are managed internally
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import Mock
|
||||
import warnings
|
||||
|
||||
from google.adk.models.lite_llm import _content_to_message_param
|
||||
from google.adk.models.lite_llm import _function_declaration_to_tool_param
|
||||
@@ -1574,3 +1575,51 @@ def test_get_completion_inputs_generation_params():
|
||||
# Should not include max_output_tokens
|
||||
assert "max_output_tokens" not in generation_params
|
||||
assert "stop_sequences" not in generation_params
|
||||
|
||||
|
||||
def test_gemini_via_litellm_warning(monkeypatch):
  """Test that Gemini via LiteLLM shows warning.

  Constructs LiteLlm with a Google AI Studio model string and checks that
  exactly one tagged UserWarning is emitted, and that it mentions the
  suppression variable and the suggested native replacement.
  """
  # Ensure environment variable is not set
  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    # Test with Google AI Studio Gemini via LiteLLM
    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")

  # "always" records every warning raised during construction; count only the
  # tagged one so unrelated warnings from dependencies cannot fail the test.
  tagged = [x for x in w if "[GEMINI_VIA_LITELLM]" in str(x.message)]
  assert len(tagged) == 1
  assert issubclass(tagged[0].category, UserWarning)
  message = str(tagged[0].message)
  assert "better performance" in message
  assert "gemini-2.5-pro-exp-03-25" in message
  # The suggested replacement must carry the model name without its prefix.
  assert "Gemini(model='gemini-2.5-pro-exp-03-25')" in message
  assert "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS" in message
|
||||
|
||||
|
||||
def test_gemini_via_litellm_warning_vertex_ai(monkeypatch):
  """Test that Vertex AI Gemini via LiteLLM shows warning.

  Constructs LiteLlm with a "vertex_ai/" model string and checks that exactly
  one tagged UserWarning is emitted, quoting the original model string and
  suggesting the prefix-free native model name.
  """
  # Ensure environment variable is not set
  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    # Test with Vertex AI Gemini via LiteLLM
    LiteLlm(model="vertex_ai/gemini-1.5-flash")

  # "always" records every warning raised during construction; count only the
  # tagged one so unrelated warnings from dependencies cannot fail the test.
  tagged = [x for x in w if "[GEMINI_VIA_LITELLM]" in str(x.message)]
  assert len(tagged) == 1
  assert issubclass(tagged[0].category, UserWarning)
  message = str(tagged[0].message)
  assert "vertex_ai/gemini-1.5-flash" in message
  # The "vertex_ai/" provider prefix must be stripped in the suggestion.
  assert "Gemini(model='gemini-1.5-flash')" in message
|
||||
|
||||
|
||||
def test_gemini_via_litellm_warning_suppressed(monkeypatch):
  """Test that Gemini via LiteLLM warning can be suppressed.

  With ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS set to "true", constructing
  LiteLlm with a Gemini model string must not emit the tagged warning.
  """
  monkeypatch.setenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "true")
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")

  # Look only for the tagged warning: "always" also records unrelated
  # warnings from dependencies, which say nothing about suppression.
  tagged = [x for x in w if "[GEMINI_VIA_LITELLM]" in str(x.message)]
  assert not tagged
|
||||
|
||||
|
||||
def test_non_gemini_litellm_no_warning():
  """Test that non-Gemini models via LiteLLM don't show warning.

  Constructing LiteLlm with an OpenAI model string must not emit the tagged
  Gemini-via-LiteLLM warning.
  """
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    # Test with non-Gemini model
    LiteLlm(model="openai/gpt-4o")

  # Look only for the tagged warning: "always" also records unrelated
  # warnings from dependencies, which this test is not about.
  tagged = [x for x in w if "[GEMINI_VIA_LITELLM]" in str(x.message)]
  assert not tagged
|
||||
|
||||
Reference in New Issue
Block a user