From 9291daaa8e399ca052f5a52dbb600d719dcc9fa8 Mon Sep 17 00:00:00 2001
From: George Weale
Date: Fri, 29 Aug 2025 11:32:07 -0700
Subject: [PATCH] chore: Add warning for using Gemini models via LiteLLM

Recommend to use Gemini outside of LiteLLM

PiperOrigin-RevId: 800971705
---
 src/google/adk/models/lite_llm.py      | 66 ++++++++++++++++++++++++++
 tests/unittests/models/test_litellm.py | 49 +++++++++++++++++++
 2 files changed, 115 insertions(+)

diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index abed4cd3..d84df9ab 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -17,6 +17,8 @@ from __future__ import annotations
 import base64
 import json
 import logging
+import os
+import re
 from typing import Any
 from typing import AsyncGenerator
 from typing import cast
@@ -29,6 +31,7 @@ from typing import Optional
 from typing import Tuple
 from typing import TypedDict
 from typing import Union
+import warnings
 
 from google.genai import types
 import litellm
@@ -672,6 +675,67 @@ Functions:
 """
 
+
+def _is_litellm_gemini_model(model_string: str) -> bool:
+  """Check if the model is a Gemini model accessed via LiteLLM.
+
+  Args:
+    model_string: A LiteLLM model string (e.g., "gemini/gemini-2.5-pro" or
+      "vertex_ai/gemini-1.5-flash")
+
+  Returns:
+    True if it's a Gemini model accessed via LiteLLM, False otherwise
+  """
+  # Matches "gemini/gemini-*" (Google AI Studio) or "vertex_ai/gemini-*" (Vertex AI).
+  pattern = r"^(gemini|vertex_ai)/gemini-"
+  return bool(re.match(pattern, model_string))
+
+
+def _extract_gemini_model_from_litellm(litellm_model: str) -> str:
+  """Extract the pure Gemini model name from a LiteLLM model string.
+
+  Args:
+    litellm_model: LiteLLM model string like "gemini/gemini-2.5-pro"
+
+  Returns:
+    Pure Gemini model name like "gemini-2.5-pro"
+  """
+  # Remove LiteLLM provider prefix
+  if "/" in litellm_model:
+    return litellm_model.split("/", 1)[1]
+  return litellm_model
+
+
+def _warn_gemini_via_litellm(model_string: str) -> None:
+  """Warn if Gemini is being used via LiteLLM.
+
+  This function logs a warning suggesting users use Gemini directly rather than
+  through LiteLLM for better performance and features.
+
+  Args:
+    model_string: The LiteLLM model string to check
+  """
+  if not _is_litellm_gemini_model(model_string):
+    return
+
+  # Check if warning should be suppressed via environment variable
+  if os.environ.get(
+      "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", ""
+  ).strip().lower() in ("1", "true", "yes", "on"):
+    return
+
+  warnings.warn(
+      f"[GEMINI_VIA_LITELLM] {model_string}: You are using Gemini via LiteLLM."
+      " For better performance, reliability, and access to latest features,"
+      " consider using Gemini directly through ADK's native Gemini"
+      f" integration. Replace LiteLlm(model='{model_string}') with"
+      f" Gemini(model='{_extract_gemini_model_from_litellm(model_string)}')."
+      " Set ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS=true to suppress this"
+      " warning.",
+      category=UserWarning,
+      stacklevel=3,
+  )
+
 
 class LiteLlm(BaseLlm):
   """Wrapper around litellm.
 
@@ -708,6 +772,8 @@ class LiteLlm(BaseLlm):
       **kwargs: Additional arguments to pass to the litellm completion api.
     """
     super().__init__(model=model, **kwargs)
+    # Warn if using Gemini via LiteLLM
+    _warn_gemini_via_litellm(model)
     self._additional_args = kwargs
     # preventing generation call with llm_client
     # and overriding messages, tools and stream which are managed internally
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index ad72d3c3..a7152f55 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -16,6 +16,7 @@
 import json
 from unittest.mock import AsyncMock
 from unittest.mock import Mock
+import warnings
 
 from google.adk.models.lite_llm import _content_to_message_param
 from google.adk.models.lite_llm import _function_declaration_to_tool_param
@@ -1574,3 +1575,51 @@ def test_get_completion_inputs_generation_params():
   # Should not include max_output_tokens
   assert "max_output_tokens" not in generation_params
   assert "stop_sequences" not in generation_params
+
+
+def test_gemini_via_litellm_warning(monkeypatch):
+  """Test that Gemini via LiteLLM shows warning."""
+  # Ensure environment variable is not set
+  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with Google AI Studio Gemini via LiteLLM
+    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")
+    assert len(w) == 1
+    assert issubclass(w[0].category, UserWarning)
+    assert "[GEMINI_VIA_LITELLM]" in str(w[0].message)
+    assert "better performance" in str(w[0].message)
+    assert "gemini-2.5-pro-exp-03-25" in str(w[0].message)
+    assert "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS" in str(w[0].message)
+
+
+def test_gemini_via_litellm_warning_vertex_ai(monkeypatch):
+  """Test that Vertex AI Gemini via LiteLLM shows warning."""
+  # Ensure environment variable is not set
+  monkeypatch.delenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", raising=False)
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with Vertex AI Gemini via LiteLLM
+    LiteLlm(model="vertex_ai/gemini-1.5-flash")
+    assert len(w) == 1
+    assert issubclass(w[0].category, UserWarning)
+    assert "[GEMINI_VIA_LITELLM]" in str(w[0].message)
+    assert "vertex_ai/gemini-1.5-flash" in str(w[0].message)
+
+
+def test_gemini_via_litellm_warning_suppressed(monkeypatch):
+  """Test that Gemini via LiteLLM warning can be suppressed."""
+  monkeypatch.setenv("ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", "true")
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    LiteLlm(model="gemini/gemini-2.5-pro-exp-03-25")
+    assert len(w) == 0
+
+
+def test_non_gemini_litellm_no_warning():
+  """Test that non-Gemini models via LiteLLM don't show warning."""
+  with warnings.catch_warnings(record=True) as w:
+    warnings.simplefilter("always")
+    # Test with non-Gemini model
+    LiteLlm(model="openai/gpt-4o")
+    assert len(w) == 0