feat: Allow users to specify the embedding model for file retrieval

Use the Gemini embedding model as the default when no embedding model is specified.

PiperOrigin-RevId: 801161505
This commit is contained in:
Xiang (Sean) Zhou
2025-08-29 23:34:50 -07:00
committed by Copybara-Service
parent 0c87907bcb
commit 67f23df25a
3 changed files with 206 additions and 11 deletions
+10 -9
View File
@@ -127,15 +127,16 @@ docs = [
# Optional extensions
extensions = [
  "anthropic>=0.43.0",                            # For anthropic model support
  "beautifulsoup4>=3.2.2",                        # For load_web_page tool.
  "crewai[tools];python_version>='3.10'",         # For CrewaiTool
  "docker>=7.0.0",                                # For ContainerCodeExecutor
  "langgraph>=0.2.60",                            # For LangGraphAgent
  "litellm>=1.75.5",                              # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
  "llama-index-readers-file>=0.4.0",              # For retrieval using LlamaIndex.
  "llama-index-embeddings-google-genai>=0.3.0",   # For files retrieval using LlamaIndex.
  "lxml>=5.3.0",                                  # For load_web_page tool.
  "toolbox-core>=0.1.0",                          # For tools.toolbox_toolset.ToolboxToolset
]
@@ -17,23 +17,64 @@
from __future__ import annotations
import logging
from typing import Optional
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex
from llama_index.core.base.embeddings.base import BaseEmbedding
from .llama_index_retrieval import LlamaIndexRetrieval
logger = logging.getLogger("google_adk." + __name__)
def _get_default_embedding_model() -> BaseEmbedding:
  """Build the default Google Gemini embedding model.

  The Google GenAI embedding integration is imported lazily, inside this
  function, so the optional package is only required when this default is
  actually requested.

  Returns:
    A GoogleGenAIEmbedding instance configured with the
    "text-embedding-004" model.

  Raises:
    ImportError: If the llama-index-embeddings-google-genai package cannot be
      imported. The original import failure is chained as the cause.
  """
  try:
    from llama_index.embeddings.google_genai import GoogleGenAIEmbedding
  except ImportError as e:
    raise ImportError(
        "llama-index-embeddings-google-genai package not found. "
        "Please run: pip install llama-index-embeddings-google-genai"
    ) from e

  # Construct outside the try block: only a failed import means the package
  # is missing. An error raised while building the model must surface as-is
  # instead of being relabelled as "package not found".
  return GoogleGenAIEmbedding(model_name="text-embedding-004")
class FilesRetrieval(LlamaIndexRetrieval):
  """Retrieval tool backed by a vector index built from a directory of files."""

  def __init__(
      self,
      *,
      name: str,
      description: str,
      input_dir: str,
      embedding_model: Optional[BaseEmbedding] = None,
  ):
    """Initialize FilesRetrieval with an optional embedding model.

    The files under `input_dir` are loaded and indexed immediately, during
    construction.

    Args:
      name: Name of the tool.
      description: Description of the tool.
      input_dir: Directory path containing files to index.
      embedding_model: Optional custom embedding model. If None, defaults to
        Google's text-embedding-004 model.

    Raises:
      ImportError: If no embedding model is given and the default Google
        embedding package is not installed.
    """
    self.input_dir = input_dir

    # Resolve the default lazily so callers that pass their own model never
    # need the optional Google embedding package.
    if embedding_model is None:
      embedding_model = _get_default_embedding_model()

    logger.info("Loading data from %s", input_dir)
    retriever = VectorStoreIndex.from_documents(
        SimpleDirectoryReader(input_dir).load_data(),
        embed_model=embedding_model,
    ).as_retriever()
    super().__init__(name=name, description=description, retriever=retriever)
@@ -0,0 +1,153 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FilesRetrieval tool."""
import sys
import unittest.mock as mock
from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model
from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
from llama_index.core.base.embeddings.base import BaseEmbedding
import pytest
def _constant_vector():
  """Return a new 384-element list in which every entry is 0.1."""
  return [0.1] * 384


class MockEmbedding(BaseEmbedding):
  """Stub embedding model that returns the same vector for any input.

  Both the sync and async hooks ignore their argument and return a fresh
  constant vector, so tests never call a real embedding backend.
  """

  def _get_query_embedding(self, query):
    """Return the constant vector for a query (the query is ignored)."""
    return _constant_vector()

  def _get_text_embedding(self, text):
    """Return the constant vector for a text (the text is ignored)."""
    return _constant_vector()

  async def _aget_query_embedding(self, query):
    """Async variant of the query hook; returns the constant vector."""
    return _constant_vector()

  async def _aget_text_embedding(self, text):
    """Async variant of the text hook; returns the constant vector."""
    return _constant_vector()
# Dotted path of the factory that FilesRetrieval calls when no embedding
# model is supplied; patched in tests to avoid the real Google model.
_DEFAULT_MODEL_FACTORY = (
    "google.adk.tools.retrieval.files_retrieval._get_default_embedding_model"
)


def _write_sample_document(directory):
  """Create the single text file that each indexing test loads."""
  sample = directory / "test.txt"
  sample.write_text("This is a test document for retrieval testing.")


class TestFilesRetrieval:
  """Tests for FilesRetrieval and its default embedding model factory."""

  def test_files_retrieval_with_custom_embedding(self, tmp_path):
    """Test FilesRetrieval with custom embedding model."""
    _write_sample_document(tmp_path)

    retrieval = FilesRetrieval(
        name="test_retrieval",
        description="Test retrieval tool",
        input_dir=str(tmp_path),
        embedding_model=MockEmbedding(),
    )

    assert retrieval.name == "test_retrieval"
    assert retrieval.input_dir == str(tmp_path)
    assert retrieval.retriever is not None

  def test_files_retrieval_uses_default_embedding(self, tmp_path):
    """Test FilesRetrieval uses default embedding when none provided."""
    _write_sample_document(tmp_path)

    with mock.patch(_DEFAULT_MODEL_FACTORY) as mock_get_default_embedding:
      mock_get_default_embedding.return_value = MockEmbedding()

      retrieval = FilesRetrieval(
          name="test_retrieval",
          description="Test retrieval tool",
          input_dir=str(tmp_path),
      )

      mock_get_default_embedding.assert_called_once()

    assert retrieval.name == "test_retrieval"
    assert retrieval.input_dir == str(tmp_path)

  def test_get_default_embedding_model_import_error(self):
    """Test _get_default_embedding_model handles ImportError correctly."""
    # Simulate the package not being installed by making its import fail
    # while every other import is forwarded to the real implementation.
    import builtins

    real_import = builtins.__import__

    def failing_import(name, *args, **kwargs):
      if name != "llama_index.embeddings.google_genai":
        return real_import(name, *args, **kwargs)
      raise ImportError(
          "No module named 'llama_index.embeddings.google_genai'"
      )

    with mock.patch("builtins.__import__", side_effect=failing_import):
      with pytest.raises(ImportError) as exc_info:
        _get_default_embedding_model()

    # The failure must be re-raised as our ImportError with install guidance.
    message = str(exc_info.value)
    assert "llama-index-embeddings-google-genai package not found" in message
    assert "pip install llama-index-embeddings-google-genai" in message

  def test_get_default_embedding_model_success(self):
    """Test _get_default_embedding_model returns Google embedding when available."""
    # Skip on Python 3.9, where llama_index.embeddings.google_genai may not
    # be available.
    if sys.version_info < (3, 10):
      pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+")

    # Register a stand-in module so the function's import resolves to it
    # rather than to the real package.
    mock_embedding_instance = MockEmbedding()
    mock_module = mock.MagicMock()
    mock_module.GoogleGenAIEmbedding.return_value = mock_embedding_instance
    injected_modules = {"llama_index.embeddings.google_genai": mock_module}

    with mock.patch.dict("sys.modules", injected_modules):
      result = _get_default_embedding_model()

      mock_module.GoogleGenAIEmbedding.assert_called_once_with(
          model_name="text-embedding-004"
      )
      assert result == mock_embedding_instance

  @mock.patch(_DEFAULT_MODEL_FACTORY)
  def test_backward_compatibility(self, mock_get_default, tmp_path):
    """Test that existing code without embedding_model parameter still works."""
    _write_sample_document(tmp_path)
    mock_get_default.return_value = MockEmbedding()

    # This should work exactly like before - no embedding_model parameter
    retrieval = FilesRetrieval(
        name="test_retrieval",
        description="Test retrieval tool",
        input_dir=str(tmp_path),
    )

    assert retrieval.name == "test_retrieval"
    assert retrieval.input_dir == str(tmp_path)
    mock_get_default.assert_called_once()