diff --git a/pyproject.toml b/pyproject.toml index e89c9656..1105e345 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,15 +127,16 @@ docs = [ # Optional extensions extensions = [ - "anthropic>=0.43.0", # For anthropic model support - "beautifulsoup4>=3.2.2", # For load_web_page tool. - "crewai[tools];python_version>='3.10'", # For CrewaiTool - "docker>=7.0.0", # For ContainerCodeExecutor - "langgraph>=0.2.60", # For LangGraphAgent - "litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it - "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. - "lxml>=5.3.0", # For load_web_page tool. - "toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset + "anthropic>=0.43.0", # For anthropic model support + "beautifulsoup4>=3.2.2", # For load_web_page tool. + "crewai[tools];python_version>='3.10'", # For CrewaiTool + "docker>=7.0.0", # For ContainerCodeExecutor + "langgraph>=0.2.60", # For LangGraphAgent + "litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it + "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex. + "llama-index-embeddings-google-genai>=0.3.0",# For files retrieval using LlamaIndex. + "lxml>=5.3.0", # For load_web_page tool. + "toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset ] diff --git a/src/google/adk/tools/retrieval/files_retrieval.py b/src/google/adk/tools/retrieval/files_retrieval.py index 00142667..9f143678 100644 --- a/src/google/adk/tools/retrieval/files_retrieval.py +++ b/src/google/adk/tools/retrieval/files_retrieval.py @@ -17,23 +17,64 @@ from __future__ import annotations import logging +from typing import Optional from llama_index.core import SimpleDirectoryReader from llama_index.core import VectorStoreIndex +from llama_index.core.base.embeddings.base import BaseEmbedding from .llama_index_retrieval import LlamaIndexRetrieval logger = logging.getLogger("google_adk." + __name__) +def _get_default_embedding_model() -> BaseEmbedding: + """Get the default Google Gemini embedding model. + + Returns: + GoogleGenAIEmbedding instance configured with text-embedding-004 model. + + Raises: + ImportError: If llama-index-embeddings-google-genai package is not installed. + """ + try: + from llama_index.embeddings.google_genai import GoogleGenAIEmbedding + + return GoogleGenAIEmbedding(model_name="text-embedding-004") + except ImportError as e: + raise ImportError( + "llama-index-embeddings-google-genai package not found. " + "Please run: pip install llama-index-embeddings-google-genai" + ) from e + + class FilesRetrieval(LlamaIndexRetrieval): - def __init__(self, *, name: str, description: str, input_dir: str): + def __init__( + self, + *, + name: str, + description: str, + input_dir: str, + embedding_model: Optional[BaseEmbedding] = None, + ): + """Initialize FilesRetrieval with optional embedding model. + Args: + name: Name of the tool. + description: Description of the tool. + input_dir: Directory path containing files to index. + embedding_model: Optional custom embedding model. If None, defaults to + Google's text-embedding-004 model. + """ self.input_dir = input_dir + if embedding_model is None: + embedding_model = _get_default_embedding_model() + logger.info("Loading data from %s", input_dir) retriever = VectorStoreIndex.from_documents( - SimpleDirectoryReader(input_dir).load_data() + SimpleDirectoryReader(input_dir).load_data(), + embed_model=embedding_model, ).as_retriever() super().__init__(name=name, description=description, retriever=retriever) diff --git a/tests/unittests/tools/retrieval/test_files_retrieval.py b/tests/unittests/tools/retrieval/test_files_retrieval.py new file mode 100644 index 00000000..ea4b99cd --- /dev/null +++ b/tests/unittests/tools/retrieval/test_files_retrieval.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FilesRetrieval tool.""" + +import sys +import unittest.mock as mock + +from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model +from google.adk.tools.retrieval.files_retrieval import FilesRetrieval +from llama_index.core.base.embeddings.base import BaseEmbedding +import pytest + + +class MockEmbedding(BaseEmbedding): + """Mock embedding model for testing.""" + + def _get_query_embedding(self, query): + return [0.1] * 384 + + def _get_text_embedding(self, text): + return [0.1] * 384 + + async def _aget_query_embedding(self, query): + return [0.1] * 384 + + async def _aget_text_embedding(self, text): + return [0.1] * 384 + + +class TestFilesRetrieval: + + def test_files_retrieval_with_custom_embedding(self, tmp_path): + """Test FilesRetrieval with custom embedding model.""" + # Create test file + test_file = tmp_path / "test.txt" + test_file.write_text("This is a test document for retrieval testing.") + + custom_embedding = MockEmbedding() + retrieval = FilesRetrieval( + name="test_retrieval", + description="Test retrieval tool", + input_dir=str(tmp_path), + embedding_model=custom_embedding, + ) + + assert retrieval.name == "test_retrieval" + assert retrieval.input_dir == str(tmp_path) + assert retrieval.retriever is not None + + @mock.patch( + "google.adk.tools.retrieval.files_retrieval._get_default_embedding_model" + ) + def test_files_retrieval_uses_default_embedding( + self, mock_get_default_embedding, tmp_path + ): + """Test FilesRetrieval uses default embedding when none provided.""" + # Create test file + test_file = tmp_path / "test.txt" + test_file.write_text("This is a test document for retrieval testing.") + + mock_embedding = MockEmbedding() + mock_get_default_embedding.return_value = mock_embedding + + retrieval = FilesRetrieval( + name="test_retrieval", + description="Test retrieval tool", + input_dir=str(tmp_path), + ) + + mock_get_default_embedding.assert_called_once() + assert retrieval.name == "test_retrieval" + assert retrieval.input_dir == str(tmp_path) + + def test_get_default_embedding_model_import_error(self): + """Test _get_default_embedding_model handles ImportError correctly.""" + # Simulate the package not being installed by making import fail + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == "llama_index.embeddings.google_genai": + raise ImportError( + "No module named 'llama_index.embeddings.google_genai'" + ) + return original_import(name, *args, **kwargs) + + with mock.patch("builtins.__import__", side_effect=mock_import): + with pytest.raises(ImportError) as exc_info: + _get_default_embedding_model() + + # The exception should be re-raised as our custom ImportError with helpful message + assert "llama-index-embeddings-google-genai package not found" in str( + exc_info.value + ) + assert "pip install llama-index-embeddings-google-genai" in str( + exc_info.value + ) + + def test_get_default_embedding_model_success(self): + """Test _get_default_embedding_model returns Google embedding when available.""" + # Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available + if sys.version_info < (3, 10): + pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+") + + # Mock the module creation to avoid import issues + mock_module = mock.MagicMock() + mock_embedding_instance = MockEmbedding() + mock_module.GoogleGenAIEmbedding.return_value = mock_embedding_instance + + with mock.patch.dict( + "sys.modules", {"llama_index.embeddings.google_genai": mock_module} + ): + result = _get_default_embedding_model() + + mock_module.GoogleGenAIEmbedding.assert_called_once_with( + model_name="text-embedding-004" + ) + assert result == mock_embedding_instance + + def test_backward_compatibility(self, tmp_path): + """Test that existing code without embedding_model parameter still works.""" + # Create test file + test_file = tmp_path / "test.txt" + test_file.write_text("This is a test document for retrieval testing.") + + with mock.patch( + "google.adk.tools.retrieval.files_retrieval._get_default_embedding_model" + ) as mock_get_default: + mock_get_default.return_value = MockEmbedding() + + # This should work exactly like before - no embedding_model parameter + retrieval = FilesRetrieval( + name="test_retrieval", + description="Test retrieval tool", + input_dir=str(tmp_path), + ) + + assert retrieval.name == "test_retrieval" + assert retrieval.input_dir == str(tmp_path) + mock_get_default.assert_called_once()