From 7c282973ea193841fee79f90b8a91c5e02627ccc Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Wed, 14 Jan 2026 19:56:13 -0800 Subject: [PATCH] fix: Support Generator and AsyncGenerator tool declaration use yield type as return type Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 856459995 --- .../tools/_automatic_function_calling_util.py | 17 ++++ .../tools/test_from_function_with_options.py | 77 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/src/google/adk/tools/_automatic_function_calling_util.py b/src/google/adk/tools/_automatic_function_calling_util.py index 92df8871..2b00c799 100644 --- a/src/google/adk/tools/_automatic_function_calling_util.py +++ b/src/google/adk/tools/_automatic_function_calling_util.py @@ -14,12 +14,15 @@ from __future__ import annotations +import collections.abc import inspect from types import FunctionType import typing from typing import Any from typing import Callable from typing import Dict +from typing import get_args +from typing import get_origin from typing import Optional from typing import Union @@ -391,6 +394,20 @@ def from_function_with_options( return_annotation = inspect.signature(func).return_annotation + # Handle AsyncGenerator and Generator return types (streaming tools) + # AsyncGenerator[YieldType, SendType] -> use YieldType as response schema + # Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema + origin = get_origin(return_annotation) + if origin is not None and ( + origin is collections.abc.AsyncGenerator + or origin is collections.abc.Generator + ): + type_args = get_args(return_annotation) + if type_args: + # First type argument is the yield type + yield_type = type_args[0] + return_annotation = yield_type + # Handle functions with no return annotation if return_annotation is inspect._empty: # Functions with no return annotation can return any type diff --git a/tests/unittests/tools/test_from_function_with_options.py b/tests/unittests/tools/test_from_function_with_options.py index 61670a26..eae16453 100644 --- a/tests/unittests/tools/test_from_function_with_options.py +++ b/tests/unittests/tools/test_from_function_with_options.py @@ -14,7 +14,9 @@ from collections.abc import Sequence from typing import Any +from typing import AsyncGenerator from typing import Dict +from typing import Generator from google.adk.tools import _automatic_function_calling_util from google.adk.utils.variant_utils import GoogleLLMVariant @@ -242,3 +244,78 @@ def test_from_function_with_collections_return_type(): assert declaration.name == 'test_function' assert declaration.response.type == types.Type.ARRAY assert declaration.response.items.type == types.Type.STRING + + +def test_from_function_with_async_generator_return_vertex(): + """Test from_function_with_options with AsyncGenerator return for VERTEX_AI.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (str) from AsyncGenerator[str, None] + assert declaration.response is not None + assert declaration.response.type == types.Type.STRING + + +def test_from_function_with_async_generator_return_gemini(): + """Test from_function_with_options with AsyncGenerator return for GEMINI_API.""" + + async def test_function(param: str) -> AsyncGenerator[str, None]: + """A streaming function that yields strings.""" + yield param + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.GEMINI_API + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # GEMINI_API should not have response schema + assert declaration.response is None + + +def test_from_function_with_generator_return_vertex(): + """Test from_function_with_options with Generator return for VERTEX_AI.""" + + def test_function(param: str) -> Generator[int, None, None]: + """A streaming function that yields integers.""" + yield 42 + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (int) from Generator[int, None, None] + assert declaration.response is not None + assert declaration.response.type == types.Type.INTEGER + + +def test_from_function_with_async_generator_complex_yield_type_vertex(): + """Test from_function_with_options with AsyncGenerator yielding dict.""" + + async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]: + """A streaming function that yields dicts.""" + yield {'result': param} + + declaration = _automatic_function_calling_util.from_function_with_options( + test_function, GoogleLLMVariant.VERTEX_AI + ) + + assert declaration.name == 'test_function' + assert declaration.parameters.type == 'OBJECT' + assert declaration.parameters.properties['param'].type == 'STRING' + # VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator + assert declaration.response is not None + assert declaration.response.type == types.Type.OBJECT