feat: Added support for enums as arguments for function tools (#3088)

* feat: Added support for enums as arguments for function tools

* feat: Add default value support for function tools
fix: Add more test cases inside `test_build_function_declaration.py` for passing Enums as arguments

* fix: format code with pyink

---------

Co-authored-by: Wei Sun (Jack) <weisun@google.com>
Co-authored-by: Yvonne Yu <150068659+yyyu-google@users.noreply.github.com>
This commit is contained in:
SOORAJ TS
2025-10-28 23:05:02 +05:30
committed by GitHub
parent b17c8f19e5
commit 240ef5beea
2 changed files with 46 additions and 1 deletions
@@ -15,6 +15,7 @@
from __future__ import annotations
from enum import Enum
import inspect
import logging
import types as typing_types
@@ -75,7 +76,7 @@ def _raise_if_schema_unsupported(
):
if variant == GoogleLLMVariant.GEMINI_API:
_raise_for_any_of_if_mldev(schema)
_update_for_default_if_mldev(schema)
# _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value
def _is_default_value_compatible(
@@ -145,6 +146,20 @@ def _parse_schema_from_parameter(
schema.type = _py_builtin_type_to_schema_type[param.annotation]
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
schema.type = types.Type.STRING
schema.enum = [e.value for e in param.annotation]
if param.default is not inspect.Parameter.empty:
default_value = (
param.default.value
if isinstance(param.default, Enum)
else param.default
)
if default_value not in schema.enum:
raise ValueError(default_value_error_msg)
schema.default = default_value
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
# only parse simple UnionType, example int | str | float | bool
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Dict
from typing import List
@@ -22,6 +23,7 @@ from google.genai import types
# TODO: crewai requires python 3.10 as minimum
# from crewai_tools import FileReadTool
from pydantic import BaseModel
import pytest
def test_string_input():
@@ -220,6 +222,34 @@ def test_list():
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
def test_enums():
class InputEnum(Enum):
AGENT = 'agent'
TOOL = 'tool'
def simple_function(input: InputEnum = InputEnum.AGENT):
return input.value
function_decl = _automatic_function_calling_util.build_function_declaration(
func=simple_function
)
assert function_decl.name == 'simple_function'
assert function_decl.parameters.type == 'OBJECT'
assert function_decl.parameters.properties['input'].type == 'STRING'
assert function_decl.parameters.properties['input'].default == 'agent'
assert function_decl.parameters.properties['input'].enum == ['agent', 'tool']
def simple_function_with_wrong_enum(input: InputEnum = 'WRONG_ENUM'):
return input.value
with pytest.raises(ValueError):
_automatic_function_calling_util.build_function_declaration(
func=simple_function_with_wrong_enum
)
def test_basemodel_list():
class ChildInput(BaseModel):
input_str: str