You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Allow user specify embedding model for file retrieval
And use Gemini embedding model as default model if no embedding model is specified. PiperOrigin-RevId: 801161505
This commit is contained in:
committed by
Copybara-Service
parent
0c87907bcb
commit
67f23df25a
+10
-9
@@ -127,15 +127,16 @@ docs = [
|
||||
|
||||
# Optional extensions
|
||||
extensions = [
|
||||
"anthropic>=0.43.0", # For anthropic model support
|
||||
"beautifulsoup4>=3.2.2", # For load_web_page tool.
|
||||
"crewai[tools];python_version>='3.10'", # For CrewaiTool
|
||||
"docker>=7.0.0", # For ContainerCodeExecutor
|
||||
"langgraph>=0.2.60", # For LangGraphAgent
|
||||
"litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
|
||||
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
|
||||
"lxml>=5.3.0", # For load_web_page tool.
|
||||
"toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
|
||||
"anthropic>=0.43.0", # For anthropic model support
|
||||
"beautifulsoup4>=3.2.2", # For load_web_page tool.
|
||||
"crewai[tools];python_version>='3.10'", # For CrewaiTool
|
||||
"docker>=7.0.0", # For ContainerCodeExecutor
|
||||
"langgraph>=0.2.60", # For LangGraphAgent
|
||||
"litellm>=1.75.5", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
|
||||
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
|
||||
"llama-index-embeddings-google-genai>=0.3.0",# For files retrieval using LlamaIndex.
|
||||
"lxml>=5.3.0", # For load_web_page tool.
|
||||
"toolbox-core>=0.1.0", # For tools.toolbox_toolset.ToolboxToolset
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -17,23 +17,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from .llama_index_retrieval import LlamaIndexRetrieval
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
def _get_default_embedding_model() -> BaseEmbedding:
|
||||
"""Get the default Google Gemini embedding model.
|
||||
|
||||
Returns:
|
||||
GoogleGenAIEmbedding instance configured with text-embedding-004 model.
|
||||
|
||||
Raises:
|
||||
ImportError: If llama-index-embeddings-google-genai package is not installed.
|
||||
"""
|
||||
try:
|
||||
from llama_index.embeddings.google_genai import GoogleGenAIEmbedding
|
||||
|
||||
return GoogleGenAIEmbedding(model_name="text-embedding-004")
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"llama-index-embeddings-google-genai package not found. "
|
||||
"Please run: pip install llama-index-embeddings-google-genai"
|
||||
) from e
|
||||
|
||||
|
||||
class FilesRetrieval(LlamaIndexRetrieval):
|
||||
|
||||
def __init__(self, *, name: str, description: str, input_dir: str):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: str,
|
||||
input_dir: str,
|
||||
embedding_model: Optional[BaseEmbedding] = None,
|
||||
):
|
||||
"""Initialize FilesRetrieval with optional embedding model.
|
||||
|
||||
Args:
|
||||
name: Name of the tool.
|
||||
description: Description of the tool.
|
||||
input_dir: Directory path containing files to index.
|
||||
embedding_model: Optional custom embedding model. If None, defaults to
|
||||
Google's text-embedding-004 model.
|
||||
"""
|
||||
self.input_dir = input_dir
|
||||
|
||||
if embedding_model is None:
|
||||
embedding_model = _get_default_embedding_model()
|
||||
|
||||
logger.info("Loading data from %s", input_dir)
|
||||
retriever = VectorStoreIndex.from_documents(
|
||||
SimpleDirectoryReader(input_dir).load_data()
|
||||
SimpleDirectoryReader(input_dir).load_data(),
|
||||
embed_model=embedding_model,
|
||||
).as_retriever()
|
||||
super().__init__(name=name, description=description, retriever=retriever)
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for FilesRetrieval tool."""
|
||||
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
|
||||
from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model
|
||||
from google.adk.tools.retrieval.files_retrieval import FilesRetrieval
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
import pytest
|
||||
|
||||
|
||||
class MockEmbedding(BaseEmbedding):
|
||||
"""Mock embedding model for testing."""
|
||||
|
||||
def _get_query_embedding(self, query):
|
||||
return [0.1] * 384
|
||||
|
||||
def _get_text_embedding(self, text):
|
||||
return [0.1] * 384
|
||||
|
||||
async def _aget_query_embedding(self, query):
|
||||
return [0.1] * 384
|
||||
|
||||
async def _aget_text_embedding(self, text):
|
||||
return [0.1] * 384
|
||||
|
||||
|
||||
class TestFilesRetrieval:
|
||||
|
||||
def test_files_retrieval_with_custom_embedding(self, tmp_path):
|
||||
"""Test FilesRetrieval with custom embedding model."""
|
||||
# Create test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("This is a test document for retrieval testing.")
|
||||
|
||||
custom_embedding = MockEmbedding()
|
||||
retrieval = FilesRetrieval(
|
||||
name="test_retrieval",
|
||||
description="Test retrieval tool",
|
||||
input_dir=str(tmp_path),
|
||||
embedding_model=custom_embedding,
|
||||
)
|
||||
|
||||
assert retrieval.name == "test_retrieval"
|
||||
assert retrieval.input_dir == str(tmp_path)
|
||||
assert retrieval.retriever is not None
|
||||
|
||||
@mock.patch(
|
||||
"google.adk.tools.retrieval.files_retrieval._get_default_embedding_model"
|
||||
)
|
||||
def test_files_retrieval_uses_default_embedding(
|
||||
self, mock_get_default_embedding, tmp_path
|
||||
):
|
||||
"""Test FilesRetrieval uses default embedding when none provided."""
|
||||
# Create test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("This is a test document for retrieval testing.")
|
||||
|
||||
mock_embedding = MockEmbedding()
|
||||
mock_get_default_embedding.return_value = mock_embedding
|
||||
|
||||
retrieval = FilesRetrieval(
|
||||
name="test_retrieval",
|
||||
description="Test retrieval tool",
|
||||
input_dir=str(tmp_path),
|
||||
)
|
||||
|
||||
mock_get_default_embedding.assert_called_once()
|
||||
assert retrieval.name == "test_retrieval"
|
||||
assert retrieval.input_dir == str(tmp_path)
|
||||
|
||||
def test_get_default_embedding_model_import_error(self):
|
||||
"""Test _get_default_embedding_model handles ImportError correctly."""
|
||||
# Simulate the package not being installed by making import fail
|
||||
import builtins
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "llama_index.embeddings.google_genai":
|
||||
raise ImportError(
|
||||
"No module named 'llama_index.embeddings.google_genai'"
|
||||
)
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with mock.patch("builtins.__import__", side_effect=mock_import):
|
||||
with pytest.raises(ImportError) as exc_info:
|
||||
_get_default_embedding_model()
|
||||
|
||||
# The exception should be re-raised as our custom ImportError with helpful message
|
||||
assert "llama-index-embeddings-google-genai package not found" in str(
|
||||
exc_info.value
|
||||
)
|
||||
assert "pip install llama-index-embeddings-google-genai" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
def test_get_default_embedding_model_success(self):
|
||||
"""Test _get_default_embedding_model returns Google embedding when available."""
|
||||
# Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available
|
||||
if sys.version_info < (3, 10):
|
||||
pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+")
|
||||
|
||||
# Mock the module creation to avoid import issues
|
||||
mock_module = mock.MagicMock()
|
||||
mock_embedding_instance = MockEmbedding()
|
||||
mock_module.GoogleGenAIEmbedding.return_value = mock_embedding_instance
|
||||
|
||||
with mock.patch.dict(
|
||||
"sys.modules", {"llama_index.embeddings.google_genai": mock_module}
|
||||
):
|
||||
result = _get_default_embedding_model()
|
||||
|
||||
mock_module.GoogleGenAIEmbedding.assert_called_once_with(
|
||||
model_name="text-embedding-004"
|
||||
)
|
||||
assert result == mock_embedding_instance
|
||||
|
||||
def test_backward_compatibility(self, tmp_path):
|
||||
"""Test that existing code without embedding_model parameter still works."""
|
||||
# Create test file
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("This is a test document for retrieval testing.")
|
||||
|
||||
with mock.patch(
|
||||
"google.adk.tools.retrieval.files_retrieval._get_default_embedding_model"
|
||||
) as mock_get_default:
|
||||
mock_get_default.return_value = MockEmbedding()
|
||||
|
||||
# This should work exactly like before - no embedding_model parameter
|
||||
retrieval = FilesRetrieval(
|
||||
name="test_retrieval",
|
||||
description="Test retrieval tool",
|
||||
input_dir=str(tmp_path),
|
||||
)
|
||||
|
||||
assert retrieval.name == "test_retrieval"
|
||||
assert retrieval.input_dir == str(tmp_path)
|
||||
mock_get_default.assert_called_once()
|
||||
Reference in New Issue
Block a user