From 0ec69d05a4016adb72abf9c94f2e9ff4bdd1848c Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Fri, 18 Jul 2025 15:55:29 -0700 Subject: [PATCH] feat: Enhance LangchainTool to accept more forms of functions Now the LangchainTool can wrap: * Langchain StructuredTool (sync and async). * Langchain @Tool (sync and async). This enhance the flexibility for user and enables async functionalities. PiperOrigin-RevId: 784728061 --- .../langchain_structured_tool_agent/agent.py | 20 +++- src/google/adk/tools/langchain_tool.py | 17 ++- tests/unittests/tools/test_langchain_tool.py | 101 ++++++++++++++++++ 3 files changed, 132 insertions(+), 6 deletions(-) create mode 100644 tests/unittests/tools/test_langchain_tool.py diff --git a/contributing/samples/langchain_structured_tool_agent/agent.py b/contributing/samples/langchain_structured_tool_agent/agent.py index e9e3d232..b7119594 100644 --- a/contributing/samples/langchain_structured_tool_agent/agent.py +++ b/contributing/samples/langchain_structured_tool_agent/agent.py @@ -17,20 +17,31 @@ This agent aims to test the Langchain tool with Langchain's StructuredTool """ from google.adk.agents import Agent from google.adk.tools.langchain_tool import LangchainTool +from langchain.tools import tool from langchain_core.tools.structured import StructuredTool from pydantic import BaseModel -def add(x, y) -> int: +async def add(x, y) -> int: return x + y +@tool +def minus(x, y) -> int: + return x - y + + class AddSchema(BaseModel): x: int y: int -test_langchain_tool = StructuredTool.from_function( +class MinusSchema(BaseModel): + x: int + y: int + + +test_langchain_add_tool = StructuredTool.from_function( add, name="add", description="Adds two numbers", @@ -45,5 +56,8 @@ root_agent = Agent( "You are a helpful assistant for user questions, you have access to a" " tool that adds two numbers." ), - tools=[LangchainTool(tool=test_langchain_tool)], + tools=[ + LangchainTool(tool=test_langchain_add_tool), + LangchainTool(tool=minus), + ], ) diff --git a/src/google/adk/tools/langchain_tool.py b/src/google/adk/tools/langchain_tool.py index 1d56e440..1d91beb5 100644 --- a/src/google/adk/tools/langchain_tool.py +++ b/src/google/adk/tools/langchain_tool.py @@ -59,15 +59,26 @@ class LangchainTool(FunctionTool): name: Optional[str] = None, description: Optional[str] = None, ): - # Check if the tool has a 'run' method if not hasattr(tool, 'run') and not hasattr(tool, '_run'): - raise ValueError("Langchain tool must have a 'run' or '_run' method") + raise ValueError( + "Tool must be a Langchain tool, have a 'run' or '_run' method." + ) # Determine which function to use if isinstance(tool, StructuredTool): func = tool.func - else: + # For async tools, func might be None but coroutine exists + if func is None and hasattr(tool, 'coroutine') and tool.coroutine: + func = tool.coroutine + elif hasattr(tool, '_run') or hasattr(tool, 'run'): func = tool._run if hasattr(tool, '_run') else tool.run + else: + raise ValueError( + "This is not supported. Tool must be a Langchain tool, have a 'run'" + " or '_run' method. The tool is: ", + type(tool), + ) + super().__init__(func) # run_manager is a special parameter for langchain tool self._ignore_params.append('run_manager') diff --git a/tests/unittests/tools/test_langchain_tool.py b/tests/unittests/tools/test_langchain_tool.py new file mode 100644 index 00000000..998b3131 --- /dev/null +++ b/tests/unittests/tools/test_langchain_tool.py @@ -0,0 +1,101 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +from google.adk.tools.langchain_tool import LangchainTool +from langchain.tools import tool +from langchain_core.tools.structured import StructuredTool +from pydantic import BaseModel +import pytest + + +@tool +async def async_add_with_annotation(x, y) -> int: + """Adds two numbers""" + return x + y + + +@tool +def sync_add_with_annotation(x, y) -> int: + """Adds two numbers""" + return x + y + + +async def async_add(x, y) -> int: + return x + y + + +def sync_add(x, y) -> int: + return x + y + + +class AddSchema(BaseModel): + x: int + y: int + + +test_langchain_async_add_tool = StructuredTool.from_function( + async_add, + name="add", + description="Adds two numbers", + args_schema=AddSchema, +) + +test_langchain_sync_add_tool = StructuredTool.from_function( + sync_add, + name="add", + description="Adds two numbers", + args_schema=AddSchema, +) + + +@pytest.mark.asyncio +async def test_raw_async_function_works(): + """Test that passing a raw async function to LangchainTool works correctly.""" + langchain_tool = LangchainTool(tool=test_langchain_async_add_tool) + result = await langchain_tool.run_async( + args={"x": 1, "y": 3}, tool_context=MagicMock() + ) + assert result == 4 + + +@pytest.mark.asyncio +async def test_raw_sync_function_works(): + """Test that passing a raw sync function to LangchainTool works correctly.""" + langchain_tool = LangchainTool(tool=test_langchain_sync_add_tool) + result = await langchain_tool.run_async( + args={"x": 1, "y": 3}, tool_context=MagicMock() + ) + assert result == 4 + + +@pytest.mark.asyncio +async def test_raw_async_function_with_annotation_works(): + """Test that passing a raw async function to LangchainTool works correctly.""" + langchain_tool = LangchainTool(tool=async_add_with_annotation) + result = await langchain_tool.run_async( + args={"x": 1, "y": 3}, tool_context=MagicMock() + ) + assert result == 4 + + +@pytest.mark.asyncio +async def test_raw_sync_function_with_annotation_works(): + """Test that passing a raw sync function to LangchainTool works correctly.""" + langchain_tool = LangchainTool(tool=sync_add_with_annotation) + result = await langchain_tool.run_async( + args={"x": 1, "y": 3}, tool_context=MagicMock() + ) + assert result == 4