You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: #3036 parameter filtering for CrewAI functions with **kwargs
Merge https://github.com/google/adk-python/pull/3037 fix: [#3036](https://github.com/google/adk-python/issues/3036) - Fix FunctionTool parameter filtering to support CrewAI-style tools - Functions with **kwargs now receive all parameters except 'self' and 'tool_context' - Maintains backward compatibility with explicit parameter functions - Add comprehensive tests for **kwargs functionality Fixes parameter filtering issue where CrewAI tools using **kwargs pattern would receive empty parameter dictionaries, causing search_query and other parameters to be None. #non-breaking COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3037 from omarcevi:fix/function-tool-kwargs-parameter-filtering 012bbfcfd68e83a29635ac74718a1bd1323c5187 PiperOrigin-RevId: 825275686
This commit is contained in:
committed by
Copybara-Service
parent
15afbcd158
commit
74a3500fc5
@@ -0,0 +1,160 @@
|
||||
# CrewAI Tool **kwargs Parameter Handling
|
||||
|
||||
This sample demonstrates how `CrewaiTool` correctly handles tools with
|
||||
`**kwargs` parameters, which is a common pattern in CrewAI tools.
|
||||
|
||||
## What This Sample Demonstrates
|
||||
|
||||
### Key Feature: **kwargs Parameter Passing
|
||||
|
||||
CrewAI tools often accept arbitrary parameters via `**kwargs`:
|
||||
|
||||
```python
|
||||
def _run(self, query: str, **kwargs) -> str:
|
||||
# Extra parameters are passed through kwargs
|
||||
category = kwargs.get('category')
|
||||
date_range = kwargs.get('date_range')
|
||||
limit = kwargs.get('limit')
|
||||
```
|
||||
|
||||
The `CrewaiTool` wrapper detects this pattern and passes all parameters through
|
||||
(except framework-managed ones like `self` and `tool_context`).
|
||||
|
||||
### Contrast with Regular Tools
|
||||
|
||||
For comparison, tools without `**kwargs` only accept explicitly declared
|
||||
parameters:
|
||||
|
||||
```python
|
||||
def _run(self, query: str, category: str) -> str:
|
||||
```
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Required: CrewAI Tools (Python 3.10+)
|
||||
|
||||
```bash
|
||||
pip install 'crewai-tools>=0.2.0'
|
||||
```
|
||||
|
||||
### Required: API Key
|
||||
|
||||
```bash
|
||||
export GOOGLE_API_KEY="your-api-key-here"
|
||||
# OR
|
||||
export GOOGLE_GENAI_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
## Running the Sample
|
||||
|
||||
### Option 1: Run the Happy Path Test
|
||||
|
||||
```bash
|
||||
cd contributing/samples/crewai_tool_kwargs
|
||||
python main.py
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
============================================================
|
||||
CrewAI Tool **kwargs Parameter Test
|
||||
============================================================
|
||||
|
||||
🧪 Test 1: Basic search (no extra parameters)
|
||||
User: Search for Python tutorials
|
||||
Agent: [Uses tool and returns results]
|
||||
|
||||
🧪 Test 2: Search with filters (**kwargs test)
|
||||
User: Search for machine learning articles, filtered by...
|
||||
Agent: [Uses tool with category, date_range, and limit parameters]
|
||||
|
||||
============================================================
|
||||
✅ Happy path test completed successfully!
|
||||
============================================================
|
||||
```
|
||||
|
||||
## What Gets Tested
|
||||
|
||||
✅ **CrewAI tool integration** - Wrapping a CrewAI BaseTool with ADK
|
||||
✅ **Basic parameters** - Required `query` parameter passes correctly
|
||||
✅ ****kwargs passing** - Extra parameters (category, date_range, limit) pass
|
||||
through
|
||||
✅ **End-to-end execution** - Tool executes and returns results to agent
|
||||
|
||||
## Code Structure
|
||||
|
||||
```
|
||||
crewai_tool_kwargs/
|
||||
├── __init__.py # Module initialization
|
||||
├── agent.py # Agent with CrewAI tool
|
||||
├── main.py # Happy path test
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
### Key Files
|
||||
|
||||
**agent.py:**
|
||||
|
||||
- Defines `CustomSearchTool` (CrewAI BaseTool with **kwargs)
|
||||
- Wraps it with `CrewaiTool`
|
||||
- Creates agent with the wrapped tool
|
||||
|
||||
**main.py:**
|
||||
|
||||
- Test 1: Basic search (no extra params)
|
||||
- Test 2: Search with filters (tests **kwargs)
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **CrewAI Tool Definition** (`agent.py`):
|
||||
```python
|
||||
class CustomSearchTool(BaseTool):
|
||||
def _run(self, query: str, **kwargs) -> str:
|
||||
# kwargs receives: category, date_range, limit, etc.
|
||||
```
|
||||
|
||||
2. **ADK Wrapping** (`agent.py`):
|
||||
```python
|
||||
adk_search_tool = CrewaiTool(
|
||||
crewai_search_tool,
|
||||
name="search_with_filters",
|
||||
description="..."
|
||||
)
|
||||
```
|
||||
|
||||
3. **LLM Function Calling** (`main.py`):
|
||||
- LLM sees the tool in function calling format
|
||||
- LLM calls with: `{query: "...", category: "...", date_range: "...", limit: 10}`
|
||||
- CrewaiTool passes ALL parameters to `**kwargs`
|
||||
|
||||
4. **Tool Execution**:
|
||||
- `query` → positional parameter
|
||||
- `category`, `date_range`, `limit` → collected in `**kwargs`
|
||||
- Tool logic uses all parameters
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### ImportError: No module named 'crewai'
|
||||
|
||||
```bash
|
||||
pip install 'crewai-tools>=0.2.0'
|
||||
```
|
||||
|
||||
### Python Version Error
|
||||
|
||||
CrewAI requires Python 3.10+:
|
||||
|
||||
```bash
|
||||
python --version # Should be 3.10 or higher
|
||||
```
|
||||
|
||||
### Missing API Key
|
||||
|
||||
```bash
|
||||
export GOOGLE_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
## Related
|
||||
|
||||
- Parent class: `FunctionTool` - Base class for all function-based tools
|
||||
- Unit tests: `tests/unittests/tools/test_crewai_tool.py`
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sample demonstrating CrewAI tool with **kwargs parameter handling.
|
||||
|
||||
This sample shows how CrewaiTool correctly passes arbitrary parameters
|
||||
through **kwargs, which is a common pattern in CrewAI tools.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from google.adk import Agent
|
||||
from google.adk.tools.crewai_tool import CrewaiTool
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class SearchInput(BaseModel):
|
||||
"""Input schema for the search tool."""
|
||||
|
||||
query: str = Field(..., description="The search query string")
|
||||
category: Optional[str] = Field(
|
||||
None, description="Filter by category (e.g., 'technology', 'science')"
|
||||
)
|
||||
date_range: Optional[str] = Field(
|
||||
None, description="Filter by date range (e.g., 'last_week', '2024')"
|
||||
)
|
||||
limit: Optional[int] = Field(
|
||||
None, description="Limit the number of results (e.g., 10, 20)"
|
||||
)
|
||||
|
||||
|
||||
class CustomSearchTool(BaseTool):
|
||||
"""A custom CrewAI tool that accepts arbitrary search parameters via **kwargs.
|
||||
|
||||
This demonstrates the key CrewAI tool pattern where tools accept
|
||||
flexible parameters through **kwargs.
|
||||
"""
|
||||
|
||||
name: str = "custom_search"
|
||||
description: str = (
|
||||
"Search for information with flexible filtering options. "
|
||||
"Accepts a query and optional filter parameters like category, "
|
||||
"date_range, limit, etc."
|
||||
)
|
||||
args_schema: type[BaseModel] = SearchInput
|
||||
|
||||
def _run(self, query: str, **kwargs) -> str:
|
||||
"""Execute search with arbitrary filter parameters.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
**kwargs: Additional filter parameters like category, date_range, limit.
|
||||
|
||||
Returns:
|
||||
A formatted string showing the query and applied filters.
|
||||
"""
|
||||
result_parts = [f"Searching for: '{query}'"]
|
||||
|
||||
if kwargs:
|
||||
result_parts.append("Applied filters:")
|
||||
for key, value in kwargs.items():
|
||||
result_parts.append(f" - {key}: {value}")
|
||||
else:
|
||||
result_parts.append("No additional filters applied.")
|
||||
|
||||
# Simulate search results
|
||||
result_parts.append(f"\nFound 3 results matching your criteria.")
|
||||
|
||||
return "\n".join(result_parts)
|
||||
|
||||
|
||||
crewai_search_tool = CustomSearchTool()
|
||||
|
||||
# Wrap it with ADK's CrewaiTool
|
||||
adk_search_tool = CrewaiTool(
|
||||
crewai_search_tool,
|
||||
name="search_with_filters",
|
||||
description=(
|
||||
"Search for information with optional filters like category, "
|
||||
"date_range, or limit"
|
||||
),
|
||||
)
|
||||
|
||||
root_agent = Agent(
|
||||
model="gemini-2.0-flash",
|
||||
name="search_agent",
|
||||
description="An agent that can search with flexible filtering options",
|
||||
instruction="""
|
||||
You are a helpful search assistant.
|
||||
When users ask you to search, use the search_with_filters tool.
|
||||
You can pass additional parameters like:
|
||||
- category: to filter by category (e.g., "technology", "science")
|
||||
- date_range: to filter by date (e.g., "last_week", "2024")
|
||||
- limit: to limit the number of results (e.g., 10, 20)
|
||||
|
||||
Always acknowledge what filters you're applying.
|
||||
""",
|
||||
tools=[adk_search_tool],
|
||||
)
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Happy path test for CrewAI tool with **kwargs parameter handling.
|
||||
|
||||
This demonstrates that CrewaiTool correctly passes arbitrary parameters
|
||||
through **kwargs to the underlying CrewAI tool.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import agent
|
||||
from dotenv import load_dotenv
|
||||
from google.adk.cli.utils import logs
|
||||
from google.adk.runners import InMemoryRunner
|
||||
from google.genai import types
|
||||
|
||||
load_dotenv(override=True)
|
||||
logs.log_to_tmp_folder()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run happy path test demonstrating **kwargs parameter passing."""
|
||||
app_name = "crewai_kwargs_test"
|
||||
user_id = "test_user"
|
||||
|
||||
runner = InMemoryRunner(
|
||||
agent=agent.root_agent,
|
||||
app_name=app_name,
|
||||
)
|
||||
|
||||
session = await runner.session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("CrewAI Tool **kwargs Parameter Test")
|
||||
print("=" * 60)
|
||||
|
||||
# Test 1: Simple search without extra parameters
|
||||
print("\n🧪 Test 1: Basic search (no extra parameters)")
|
||||
print("-" * 60)
|
||||
content1 = types.Content(
|
||||
role="user",
|
||||
parts=[types.Part.from_text(text="Search for Python tutorials")],
|
||||
)
|
||||
print(f"User: {content1.parts[0].text}")
|
||||
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session.id,
|
||||
new_message=content1,
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f"Agent: {event.content.parts[0].text}")
|
||||
|
||||
# Test 2: Search with extra parameters (testing **kwargs)
|
||||
print("\n🧪 Test 2: Search with filters (**kwargs test)")
|
||||
print("-" * 60)
|
||||
content2 = types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=(
|
||||
"Search for machine learning articles, filtered by category"
|
||||
" 'technology', date_range 'last_month', and limit to 10"
|
||||
" results"
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
print(f"User: {content2.parts[0].text}")
|
||||
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session.id,
|
||||
new_message=content2,
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f"Agent: {event.content.parts[0].text}")
|
||||
|
||||
# Verify success
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ Happy path test completed successfully!")
|
||||
print("=" * 60)
|
||||
print("\nVerified behaviors:")
|
||||
print(" ✅ CrewAI tool integrated with ADK agent")
|
||||
print(" ✅ Basic parameters passed correctly")
|
||||
print(" ✅ Extra parameters passed through **kwargs")
|
||||
print(" ✅ Tool executed and returned results")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -14,6 +14,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -21,6 +25,7 @@ from . import _automatic_function_calling_util
|
||||
from .function_tool import FunctionTool
|
||||
from .tool_configs import BaseToolConfig
|
||||
from .tool_configs import ToolArgsConfig
|
||||
from .tool_context import ToolContext
|
||||
|
||||
try:
|
||||
from crewai.tools import BaseTool as CrewaiBaseTool
|
||||
@@ -29,7 +34,7 @@ except ImportError as e:
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
"Crewai Tools require Python 3.10+. Please upgrade your Python version."
|
||||
'Crewai Tools require Python 3.10+. Please upgrade your Python version.'
|
||||
) from e
|
||||
else:
|
||||
raise ImportError(
|
||||
@@ -55,12 +60,72 @@ class CrewaiTool(FunctionTool):
|
||||
elif tool.name:
|
||||
# Right now, CrewAI tool name contains white spaces. White spaces are
|
||||
# not supported in our framework. So we replace them with "_".
|
||||
self.name = tool.name.replace(" ", "_").lower()
|
||||
self.name = tool.name.replace(' ', '_').lower()
|
||||
if description:
|
||||
self.description = description
|
||||
elif tool.description:
|
||||
self.description = tool.description
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self, *, args: dict[str, Any], tool_context: ToolContext
|
||||
) -> Any:
|
||||
"""Override run_async to handle CrewAI-specific parameter filtering.
|
||||
|
||||
CrewAI tools use **kwargs pattern, so we need special parameter filtering
|
||||
logic that allows all parameters to pass through while removing only
|
||||
reserved parameters like 'self' and 'tool_context'.
|
||||
|
||||
Note: 'tool_context' is removed from the initial args dictionary to prevent
|
||||
duplicates, but is re-added if the function signature explicitly requires it
|
||||
as a parameter.
|
||||
"""
|
||||
# Preprocess arguments (includes Pydantic model conversion)
|
||||
args_to_call = self._preprocess_args(args)
|
||||
|
||||
signature = inspect.signature(self.func)
|
||||
valid_params = {param for param in signature.parameters}
|
||||
|
||||
# Check if function accepts **kwargs
|
||||
has_kwargs = any(
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for param in signature.parameters.values()
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
# For functions with **kwargs, we pass all arguments. We defensively
|
||||
# remove arguments like `self` that are managed by the framework and not
|
||||
# intended to be passed through **kwargs.
|
||||
args_to_call.pop('self', None)
|
||||
# We also remove `tool_context` that might have been passed in `args`,
|
||||
# as it will be explicitly injected later if it's a valid parameter.
|
||||
args_to_call.pop('tool_context', None)
|
||||
else:
|
||||
# For functions without **kwargs, use the original filtering.
|
||||
args_to_call = {
|
||||
k: v for k, v in args_to_call.items() if k in valid_params
|
||||
}
|
||||
|
||||
# Inject tool_context if it's an explicit parameter. This will add it
|
||||
# or overwrite any value that might have been passed in `args`.
|
||||
if 'tool_context' in valid_params:
|
||||
args_to_call['tool_context'] = tool_context
|
||||
|
||||
# Check for missing mandatory arguments
|
||||
mandatory_args = self._get_mandatory_args()
|
||||
missing_mandatory_args = [
|
||||
arg for arg in mandatory_args if arg not in args_to_call
|
||||
]
|
||||
|
||||
if missing_mandatory_args:
|
||||
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
|
||||
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
|
||||
{missing_mandatory_args_str}
|
||||
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
||||
return {'error': error_str}
|
||||
|
||||
return await self._invoke_callable(self.func, args_to_call)
|
||||
|
||||
@override
|
||||
def _get_declaration(self) -> types.FunctionDeclaration:
|
||||
"""Build the function declaration for the tool."""
|
||||
@@ -93,8 +158,8 @@ class CrewaiToolConfig(BaseToolConfig):
|
||||
tool: str
|
||||
"""The fully qualified path of the CrewAI tool instance."""
|
||||
|
||||
name: str = ""
|
||||
name: str = ''
|
||||
"""The name of the tool."""
|
||||
|
||||
description: str = ""
|
||||
description: str = ''
|
||||
"""The description of the tool."""
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip entire module if Python < 3.10 (must be before crewai_tool import)
|
||||
pytest.importorskip(
|
||||
"google.adk.tools.crewai_tool", reason="Requires Python 3.10+"
|
||||
)
|
||||
|
||||
from google.adk.agents.invocation_context import InvocationContext
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools.crewai_tool import CrewaiTool
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context() -> ToolContext:
|
||||
"""Fixture that provides a mock ToolContext for testing."""
|
||||
mock_invocation_context = MagicMock(spec=InvocationContext)
|
||||
mock_invocation_context.session = MagicMock(spec=Session)
|
||||
mock_invocation_context.session.state = MagicMock()
|
||||
return ToolContext(invocation_context=mock_invocation_context)
|
||||
|
||||
|
||||
def _simple_crewai_tool(*args, **kwargs):
|
||||
"""Simple CrewAI-style tool that accepts any keyword arguments."""
|
||||
return {
|
||||
"search_query": kwargs.get("search_query"),
|
||||
"other_param": kwargs.get("other_param"),
|
||||
}
|
||||
|
||||
|
||||
def _crewai_tool_with_context(tool_context: ToolContext, *args, **kwargs):
|
||||
"""CrewAI tool with explicit tool_context parameter."""
|
||||
return {
|
||||
"search_query": kwargs.get("search_query"),
|
||||
"tool_context_present": bool(tool_context),
|
||||
}
|
||||
|
||||
|
||||
class MockCrewaiBaseTool:
|
||||
"""Mock CrewAI BaseTool for testing."""
|
||||
|
||||
def __init__(self, run_func, name="mock_tool", description="Mock tool"):
|
||||
self.run = run_func
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.args_schema = MagicMock()
|
||||
self.args_schema.model_json_schema.return_value = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_crewai_tool_initialization():
|
||||
"""Test CrewaiTool initialization with various parameters."""
|
||||
mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool)
|
||||
|
||||
# Test with custom name and description
|
||||
tool = CrewaiTool(
|
||||
mock_crewai_tool,
|
||||
name="custom_search_tool",
|
||||
description="Custom search tool description",
|
||||
)
|
||||
|
||||
assert tool.name == "custom_search_tool"
|
||||
assert tool.description == "Custom search tool description"
|
||||
assert tool.tool == mock_crewai_tool
|
||||
|
||||
|
||||
def test_crewai_tool_initialization_with_tool_defaults():
|
||||
"""Test CrewaiTool initialization using tool's default name and description."""
|
||||
mock_crewai_tool = MockCrewaiBaseTool(
|
||||
_simple_crewai_tool,
|
||||
name="Serper Dev Tool",
|
||||
description="Search the internet with Serper",
|
||||
)
|
||||
|
||||
# Test with empty name and description (should use tool defaults)
|
||||
tool = CrewaiTool(mock_crewai_tool, name="", description="")
|
||||
|
||||
assert (
|
||||
tool.name == "serper_dev_tool"
|
||||
) # Spaces replaced with underscores, lowercased
|
||||
assert tool.description == "Search the internet with Serper"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crewai_tool_basic_functionality(mock_tool_context):
|
||||
"""Test basic CrewaiTool functionality with **kwargs parameter passing."""
|
||||
mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool)
|
||||
tool = CrewaiTool(mock_crewai_tool, name="test_tool", description="Test tool")
|
||||
|
||||
# Test that **kwargs parameters are passed through correctly
|
||||
result = await tool.run_async(
|
||||
args={"search_query": "test query", "other_param": "test value"},
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
|
||||
assert result["search_query"] == "test query"
|
||||
assert result["other_param"] == "test value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crewai_tool_with_tool_context(mock_tool_context):
|
||||
"""Test CrewaiTool with a tool that has explicit tool_context parameter."""
|
||||
mock_crewai_tool = MockCrewaiBaseTool(_crewai_tool_with_context)
|
||||
tool = CrewaiTool(
|
||||
mock_crewai_tool, name="context_tool", description="Context tool"
|
||||
)
|
||||
|
||||
# Test that tool_context is properly injected
|
||||
result = await tool.run_async(
|
||||
args={"search_query": "test query"},
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
|
||||
assert result["search_query"] == "test query"
|
||||
assert result["tool_context_present"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crewai_tool_parameter_filtering(mock_tool_context):
|
||||
"""Test that CrewaiTool filters parameters for non-**kwargs functions."""
|
||||
|
||||
def explicit_params_func(arg1: str, arg2: int):
|
||||
"""Function with explicit parameters (no **kwargs)."""
|
||||
return {"arg1": arg1, "arg2": arg2}
|
||||
|
||||
mock_crewai_tool = MockCrewaiBaseTool(explicit_params_func)
|
||||
tool = CrewaiTool(
|
||||
mock_crewai_tool, name="explicit_tool", description="Explicit tool"
|
||||
)
|
||||
|
||||
# Test that unexpected parameters are filtered out
|
||||
result = await tool.run_async(
|
||||
args={
|
||||
"arg1": "test",
|
||||
"arg2": 42,
|
||||
"unexpected_param": "should_be_filtered",
|
||||
},
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
|
||||
assert result == {"arg1": "test", "arg2": 42}
|
||||
# Verify unexpected parameter was filtered out
|
||||
assert "unexpected_param" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crewai_tool_get_declaration():
|
||||
"""Test that CrewaiTool properly builds function declarations."""
|
||||
mock_crewai_tool = MockCrewaiBaseTool(_simple_crewai_tool)
|
||||
tool = CrewaiTool(mock_crewai_tool, name="test_tool", description="Test tool")
|
||||
|
||||
# Test function declaration generation
|
||||
declaration = tool._get_declaration()
|
||||
|
||||
# Verify the declaration object structure and content
|
||||
assert declaration is not None
|
||||
assert declaration.name == "test_tool"
|
||||
assert declaration.description == "Test tool"
|
||||
assert declaration.parameters is not None
|
||||
|
||||
# Verify that the args_schema was used to build the declaration
|
||||
mock_crewai_tool.args_schema.model_json_schema.assert_called_once()
|
||||
@@ -22,6 +22,15 @@ from google.adk.tools.tool_context import ToolContext
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_context() -> ToolContext:
|
||||
"""Fixture that provides a mock ToolContext for testing."""
|
||||
mock_invocation_context = MagicMock(spec=InvocationContext)
|
||||
mock_invocation_context.session = MagicMock(spec=Session)
|
||||
mock_invocation_context.session.state = MagicMock()
|
||||
return ToolContext(invocation_context=mock_invocation_context)
|
||||
|
||||
|
||||
def function_for_testing_with_no_args():
|
||||
"""Function for testing with no args."""
|
||||
pass
|
||||
@@ -394,3 +403,28 @@ async def test_run_async_with_require_confirmation():
|
||||
tool_context=tool_context_mock,
|
||||
)
|
||||
assert result == {"received_arg": "hello"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_parameter_filtering(mock_tool_context):
|
||||
"""Test that parameter filtering works correctly for functions with explicit parameters."""
|
||||
|
||||
def explicit_params_func(arg1: str, arg2: int):
|
||||
"""Function with explicit parameters (no **kwargs)."""
|
||||
return {"arg1": arg1, "arg2": arg2}
|
||||
|
||||
tool = FunctionTool(explicit_params_func)
|
||||
|
||||
# Test that unexpected parameters are still filtered out for non-kwargs functions
|
||||
result = await tool.run_async(
|
||||
args={
|
||||
"arg1": "test",
|
||||
"arg2": 42,
|
||||
"unexpected_param": "should_be_filtered",
|
||||
},
|
||||
tool_context=mock_tool_context,
|
||||
)
|
||||
|
||||
assert result == {"arg1": "test", "arg2": 42}
|
||||
# Explicitly verify that unexpected_param was filtered out and not passed to the function
|
||||
assert "unexpected_param" not in result
|
||||
|
||||
Reference in New Issue
Block a user