You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Support Generator and AsyncGenerator tool declaration
use yield type as return type Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 856459995
This commit is contained in:
committed by
Copybara-Service
parent
d4da1bb733
commit
7c282973ea
@@ -14,12 +14,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import inspect
|
||||
from types import FunctionType
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import get_args
|
||||
from typing import get_origin
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
@@ -391,6 +394,20 @@ def from_function_with_options(
|
||||
|
||||
return_annotation = inspect.signature(func).return_annotation
|
||||
|
||||
# Handle AsyncGenerator and Generator return types (streaming tools)
|
||||
# AsyncGenerator[YieldType, SendType] -> use YieldType as response schema
|
||||
# Generator[YieldType, SendType, ReturnType] -> use YieldType as response schema
|
||||
origin = get_origin(return_annotation)
|
||||
if origin is not None and (
|
||||
origin is collections.abc.AsyncGenerator
|
||||
or origin is collections.abc.Generator
|
||||
):
|
||||
type_args = get_args(return_annotation)
|
||||
if type_args:
|
||||
# First type argument is the yield type
|
||||
yield_type = type_args[0]
|
||||
return_annotation = yield_type
|
||||
|
||||
# Handle functions with no return annotation
|
||||
if return_annotation is inspect._empty:
|
||||
# Functions with no return annotation can return any type
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import Dict
|
||||
from typing import Generator
|
||||
|
||||
from google.adk.tools import _automatic_function_calling_util
|
||||
from google.adk.utils.variant_utils import GoogleLLMVariant
|
||||
@@ -242,3 +244,78 @@ def test_from_function_with_collections_return_type():
|
||||
assert declaration.name == 'test_function'
|
||||
assert declaration.response.type == types.Type.ARRAY
|
||||
assert declaration.response.items.type == types.Type.STRING
|
||||
|
||||
|
||||
def test_from_function_with_async_generator_return_vertex():
|
||||
"""Test from_function_with_options with AsyncGenerator return for VERTEX_AI."""
|
||||
|
||||
async def test_function(param: str) -> AsyncGenerator[str, None]:
|
||||
"""A streaming function that yields strings."""
|
||||
yield param
|
||||
|
||||
declaration = _automatic_function_calling_util.from_function_with_options(
|
||||
test_function, GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
|
||||
assert declaration.name == 'test_function'
|
||||
assert declaration.parameters.type == 'OBJECT'
|
||||
assert declaration.parameters.properties['param'].type == 'STRING'
|
||||
# VERTEX_AI should extract yield type (str) from AsyncGenerator[str, None]
|
||||
assert declaration.response is not None
|
||||
assert declaration.response.type == types.Type.STRING
|
||||
|
||||
|
||||
def test_from_function_with_async_generator_return_gemini():
|
||||
"""Test from_function_with_options with AsyncGenerator return for GEMINI_API."""
|
||||
|
||||
async def test_function(param: str) -> AsyncGenerator[str, None]:
|
||||
"""A streaming function that yields strings."""
|
||||
yield param
|
||||
|
||||
declaration = _automatic_function_calling_util.from_function_with_options(
|
||||
test_function, GoogleLLMVariant.GEMINI_API
|
||||
)
|
||||
|
||||
assert declaration.name == 'test_function'
|
||||
assert declaration.parameters.type == 'OBJECT'
|
||||
assert declaration.parameters.properties['param'].type == 'STRING'
|
||||
# GEMINI_API should not have response schema
|
||||
assert declaration.response is None
|
||||
|
||||
|
||||
def test_from_function_with_generator_return_vertex():
|
||||
"""Test from_function_with_options with Generator return for VERTEX_AI."""
|
||||
|
||||
def test_function(param: str) -> Generator[int, None, None]:
|
||||
"""A streaming function that yields integers."""
|
||||
yield 42
|
||||
|
||||
declaration = _automatic_function_calling_util.from_function_with_options(
|
||||
test_function, GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
|
||||
assert declaration.name == 'test_function'
|
||||
assert declaration.parameters.type == 'OBJECT'
|
||||
assert declaration.parameters.properties['param'].type == 'STRING'
|
||||
# VERTEX_AI should extract yield type (int) from Generator[int, None, None]
|
||||
assert declaration.response is not None
|
||||
assert declaration.response.type == types.Type.INTEGER
|
||||
|
||||
|
||||
def test_from_function_with_async_generator_complex_yield_type_vertex():
|
||||
"""Test from_function_with_options with AsyncGenerator yielding dict."""
|
||||
|
||||
async def test_function(param: str) -> AsyncGenerator[Dict[str, str], None]:
|
||||
"""A streaming function that yields dicts."""
|
||||
yield {'result': param}
|
||||
|
||||
declaration = _automatic_function_calling_util.from_function_with_options(
|
||||
test_function, GoogleLLMVariant.VERTEX_AI
|
||||
)
|
||||
|
||||
assert declaration.name == 'test_function'
|
||||
assert declaration.parameters.type == 'OBJECT'
|
||||
assert declaration.parameters.properties['param'].type == 'STRING'
|
||||
# VERTEX_AI should extract yield type (Dict[str, str]) from AsyncGenerator
|
||||
assert declaration.response is not None
|
||||
assert declaration.response.type == types.Type.OBJECT
|
||||
|
||||
Reference in New Issue
Block a user