You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Enhance LangchainTool to accept more forms of functions
Now the LangchainTool can wrap: * Langchain StructuredTool (sync and async). * Langchain @Tool (sync and async). This enhance the flexibility for user and enables async functionalities. PiperOrigin-RevId: 784728061
This commit is contained in:
committed by
Copybara-Service
parent
f1e0bc0b18
commit
0ec69d05a4
@@ -17,20 +17,31 @@ This agent aims to test the Langchain tool with Langchain's StructuredTool
|
||||
"""
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
from langchain.tools import tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def add(x, y) -> int:
|
||||
async def add(x, y) -> int:
|
||||
return x + y
|
||||
|
||||
|
||||
@tool
|
||||
def minus(x, y) -> int:
|
||||
return x - y
|
||||
|
||||
|
||||
class AddSchema(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
test_langchain_tool = StructuredTool.from_function(
|
||||
class MinusSchema(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
test_langchain_add_tool = StructuredTool.from_function(
|
||||
add,
|
||||
name="add",
|
||||
description="Adds two numbers",
|
||||
@@ -45,5 +56,8 @@ root_agent = Agent(
|
||||
"You are a helpful assistant for user questions, you have access to a"
|
||||
" tool that adds two numbers."
|
||||
),
|
||||
tools=[LangchainTool(tool=test_langchain_tool)],
|
||||
tools=[
|
||||
LangchainTool(tool=test_langchain_add_tool),
|
||||
LangchainTool(tool=minus),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -59,15 +59,26 @@ class LangchainTool(FunctionTool):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
):
|
||||
# Check if the tool has a 'run' method
|
||||
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
|
||||
raise ValueError("Langchain tool must have a 'run' or '_run' method")
|
||||
raise ValueError(
|
||||
"Tool must be a Langchain tool, have a 'run' or '_run' method."
|
||||
)
|
||||
|
||||
# Determine which function to use
|
||||
if isinstance(tool, StructuredTool):
|
||||
func = tool.func
|
||||
else:
|
||||
# For async tools, func might be None but coroutine exists
|
||||
if func is None and hasattr(tool, 'coroutine') and tool.coroutine:
|
||||
func = tool.coroutine
|
||||
elif hasattr(tool, '_run') or hasattr(tool, 'run'):
|
||||
func = tool._run if hasattr(tool, '_run') else tool.run
|
||||
else:
|
||||
raise ValueError(
|
||||
"This is not supported. Tool must be a Langchain tool, have a 'run'"
|
||||
" or '_run' method. The tool is: ",
|
||||
type(tool),
|
||||
)
|
||||
|
||||
super().__init__(func)
|
||||
# run_manager is a special parameter for langchain tool
|
||||
self._ignore_params.append('run_manager')
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
from langchain.tools import tool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
@tool
|
||||
async def async_add_with_annotation(x, y) -> int:
|
||||
"""Adds two numbers"""
|
||||
return x + y
|
||||
|
||||
|
||||
@tool
|
||||
def sync_add_with_annotation(x, y) -> int:
|
||||
"""Adds two numbers"""
|
||||
return x + y
|
||||
|
||||
|
||||
async def async_add(x, y) -> int:
|
||||
return x + y
|
||||
|
||||
|
||||
def sync_add(x, y) -> int:
|
||||
return x + y
|
||||
|
||||
|
||||
class AddSchema(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
|
||||
test_langchain_async_add_tool = StructuredTool.from_function(
|
||||
async_add,
|
||||
name="add",
|
||||
description="Adds two numbers",
|
||||
args_schema=AddSchema,
|
||||
)
|
||||
|
||||
test_langchain_sync_add_tool = StructuredTool.from_function(
|
||||
sync_add,
|
||||
name="add",
|
||||
description="Adds two numbers",
|
||||
args_schema=AddSchema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_async_function_works():
|
||||
"""Test that passing a raw async function to LangchainTool works correctly."""
|
||||
langchain_tool = LangchainTool(tool=test_langchain_async_add_tool)
|
||||
result = await langchain_tool.run_async(
|
||||
args={"x": 1, "y": 3}, tool_context=MagicMock()
|
||||
)
|
||||
assert result == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_sync_function_works():
|
||||
"""Test that passing a raw sync function to LangchainTool works correctly."""
|
||||
langchain_tool = LangchainTool(tool=test_langchain_sync_add_tool)
|
||||
result = await langchain_tool.run_async(
|
||||
args={"x": 1, "y": 3}, tool_context=MagicMock()
|
||||
)
|
||||
assert result == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_async_function_with_annotation_works():
|
||||
"""Test that passing a raw async function to LangchainTool works correctly."""
|
||||
langchain_tool = LangchainTool(tool=async_add_with_annotation)
|
||||
result = await langchain_tool.run_async(
|
||||
args={"x": 1, "y": 3}, tool_context=MagicMock()
|
||||
)
|
||||
assert result == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_sync_function_with_annotation_works():
|
||||
"""Test that passing a raw sync function to LangchainTool works correctly."""
|
||||
langchain_tool = LangchainTool(tool=sync_add_with_annotation)
|
||||
result = await langchain_tool.run_async(
|
||||
args={"x": 1, "y": 3}, tool_context=MagicMock()
|
||||
)
|
||||
assert result == 4
|
||||
Reference in New Issue
Block a user