You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Added support for enums as arguments for function tools (#3088)
* feat: Added support for enums as arguments for function tools * feat: Add default value support for function tools fix: Add more test cases inside `test_build_function_declaration.py` for passing Enums as arguments * fix: format code with pyink --------- Co-authored-by: Wei Sun (Jack) <weisun@google.com> Co-authored-by: Yvonne Yu <150068659+yyyu-google@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import logging
|
||||
import types as typing_types
|
||||
@@ -75,7 +76,7 @@ def _raise_if_schema_unsupported(
|
||||
):
|
||||
if variant == GoogleLLMVariant.GEMINI_API:
|
||||
_raise_for_any_of_if_mldev(schema)
|
||||
_update_for_default_if_mldev(schema)
|
||||
# _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value
|
||||
|
||||
|
||||
def _is_default_value_compatible(
|
||||
@@ -145,6 +146,20 @@ def _parse_schema_from_parameter(
|
||||
schema.type = _py_builtin_type_to_schema_type[param.annotation]
|
||||
_raise_if_schema_unsupported(variant, schema)
|
||||
return schema
|
||||
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
|
||||
schema.type = types.Type.STRING
|
||||
schema.enum = [e.value for e in param.annotation]
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
default_value = (
|
||||
param.default.value
|
||||
if isinstance(param.default, Enum)
|
||||
else param.default
|
||||
)
|
||||
if default_value not in schema.enum:
|
||||
raise ValueError(default_value_error_msg)
|
||||
schema.default = default_value
|
||||
_raise_if_schema_unsupported(variant, schema)
|
||||
return schema
|
||||
if (
|
||||
get_origin(param.annotation) is Union
|
||||
# only parse simple UnionType, example int | str | float | bool
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
@@ -22,6 +23,7 @@ from google.genai import types
|
||||
# TODO: crewai requires python 3.10 as minimum
|
||||
# from crewai_tools import FileReadTool
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
def test_string_input():
|
||||
@@ -220,6 +222,34 @@ def test_list():
|
||||
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
|
||||
|
||||
|
||||
def test_enums():
|
||||
|
||||
class InputEnum(Enum):
|
||||
AGENT = 'agent'
|
||||
TOOL = 'tool'
|
||||
|
||||
def simple_function(input: InputEnum = InputEnum.AGENT):
|
||||
return input.value
|
||||
|
||||
function_decl = _automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function
|
||||
)
|
||||
|
||||
assert function_decl.name == 'simple_function'
|
||||
assert function_decl.parameters.type == 'OBJECT'
|
||||
assert function_decl.parameters.properties['input'].type == 'STRING'
|
||||
assert function_decl.parameters.properties['input'].default == 'agent'
|
||||
assert function_decl.parameters.properties['input'].enum == ['agent', 'tool']
|
||||
|
||||
def simple_function_with_wrong_enum(input: InputEnum = 'WRONG_ENUM'):
|
||||
return input.value
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_automatic_function_calling_util.build_function_declaration(
|
||||
func=simple_function_with_wrong_enum
|
||||
)
|
||||
|
||||
|
||||
def test_basemodel_list():
|
||||
class ChildInput(BaseModel):
|
||||
input_str: str
|
||||
|
||||
Reference in New Issue
Block a user