diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 9f3c2a2c..1be0cc69 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -33,3 +33,23 @@ __all__ = [ LLMRegistry.register(Gemini) LLMRegistry.register(Gemma) LLMRegistry.register(ApigeeLlm) + +# Optionally register Claude if anthropic package is installed +try: + from .anthropic_llm import Claude + + LLMRegistry.register(Claude) + __all__.append('Claude') +except Exception: + # Claude support requires: pip install google-adk[extensions] + pass + +# Optionally register LiteLlm if litellm package is installed +try: + from .lite_llm import LiteLlm + + LLMRegistry.register(LiteLlm) + __all__.append('LiteLlm') +except Exception: + # LiteLLM support requires: pip install google-adk[extensions] + pass diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 9e3698b1..162db059 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1388,11 +1388,19 @@ class LiteLlm(BaseLlm): def supported_models(cls) -> list[str]: """Provides the list of supported models. - LiteLlm supports all models supported by litellm. We do not keep track of - these models here. So we return an empty list. + This registers common provider prefixes. LiteLlm can handle many more, + but these patterns activate the integration for the most common use cases. + See https://docs.litellm.ai/docs/providers for a full list. Returns: A list of supported models. """ - return [] + return [ + # For OpenAI models (e.g., "openai/gpt-4o") + r"openai/.*", + # For Groq models via Groq API (e.g., "groq/llama3-70b-8192") + r"groq/.*", + # For Anthropic models (e.g., "anthropic/claude-3-opus-20240229") + r"anthropic/.*", + ] diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py index 22e24d4c..852996ff 100644 --- a/src/google/adk/models/registry.py +++ b/src/google/adk/models/registry.py @@ -99,4 +99,26 @@ class LLMRegistry: if re.compile(regex).fullmatch(model): return llm_class - raise ValueError(f'Model {model} not found.') + # Provide helpful error messages for known patterns + error_msg = f'Model {model} not found.' + + # Check if it matches known patterns that require optional dependencies + if re.match(r'^claude-', model): + error_msg += ( + '\n\nClaude models require the anthropic package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install anthropic>=0.43.0' + ) + elif '/' in model: + # Any model with provider/model format likely needs LiteLLM + error_msg += ( + '\n\nProvider-style models (e.g., "provider/model-name") require' + ' the litellm package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install litellm>=1.75.5' + '\n\nSupported providers include: openai, groq, anthropic, and 100+' + ' others.' + '\nSee https://docs.litellm.ai/docs/providers for a full list.' + ) + + raise ValueError(error_msg) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index c57254db..577923f7 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -22,6 +22,9 @@ from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models.anthropic_llm import Claude +from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.models.llm_request import LlmRequest from google.adk.models.registry import LLMRegistry from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -411,3 +414,47 @@ class TestCanonicalTools: assert len(tools) == 1 assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + + +# Tests for multi-provider model support via string model names +@pytest.mark.parametrize( + 'model_name', + [ + 'gemini-1.5-flash', + 'gemini-2.0-flash-exp', + ], +) +def test_agent_with_gemini_string_model(model_name): + """Test that Agent accepts Gemini model strings and resolves to Gemini.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Gemini) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'claude-3-5-sonnet-v2@20241022', + 'claude-sonnet-4@20250514', + ], +) +def test_agent_with_claude_string_model(model_name): + """Test that Agent accepts Claude model strings and resolves to Claude.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Claude) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'groq/llama3-70b-8192', + 'anthropic/claude-3-opus-20240229', + ], +) +def test_agent_with_litellm_string_model(model_name): + """Test that Agent accepts LiteLLM provider strings.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, LiteLlm) + assert agent.canonical_model.model == model_name diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py index 70246c7b..8575064b 100644 --- a/tests/unittests/models/test_models.py +++ b/tests/unittests/models/test_models.py @@ -15,7 +15,7 @@ from google.adk import models from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini -from google.adk.models.registry import LLMRegistry +from google.adk.models.lite_llm import LiteLlm import pytest @@ -34,6 +34,7 @@ import pytest ], ) def test_match_gemini_family(model_name): + """Test that Gemini models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Gemini @@ -51,12 +52,63 @@ def test_match_gemini_family(model_name): ], ) def test_match_claude_family(model_name): - LLMRegistry.register(Claude) - + """Test that Claude models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Claude +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'openai/gpt-4o-mini', + 'groq/llama3-70b-8192', + 'groq/mixtral-8x7b-32768', + 'anthropic/claude-3-opus-20240229', + 'anthropic/claude-3-5-sonnet-20241022', + ], +) +def test_match_litellm_family(model_name): + """Test that LiteLLM models are resolved correctly.""" + assert models.LLMRegistry.resolve(model_name) is LiteLlm + + def test_non_exist_model(): with pytest.raises(ValueError) as e_info: models.LLMRegistry.resolve('non-exist-model') assert 'Model non-exist-model not found.' in str(e_info.value) + + +def test_helpful_error_for_claude_without_extensions(): + """Test that missing Claude models show helpful install instructions. + + Note: This test may pass even when anthropic IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent Claude model variant to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('claude-nonexistent-model-xyz') + + error_msg = str(e_info.value) + # The error should mention anthropic package and installation instructions + # These checks work whether or not anthropic is actually installed + assert 'Model claude-nonexistent-model-xyz not found' in error_msg + assert 'anthropic package' in error_msg + assert 'pip install' in error_msg + + +def test_helpful_error_for_litellm_without_extensions(): + """Test that missing LiteLLM models show helpful install instructions. + + Note: This test may pass even when litellm IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent provider to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('unknown-provider/gpt-4o') + + error_msg = str(e_info.value) + # The error should mention litellm package for provider-style models + assert 'Model unknown-provider/gpt-4o not found' in error_msg + assert 'litellm package' in error_msg + assert 'pip install' in error_msg + assert 'Provider-style models' in error_msg