You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
ADK changes
PiperOrigin-RevId: 794403729
This commit is contained in:
committed by
Copybara-Service
parent
e2518dc371
commit
114db93d70
@@ -15,8 +15,6 @@
|
||||
"""
|
||||
This agent aims to test the Langchain tool with Langchain's StructuredTool
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.tools.langchain_tool import LangchainTool
|
||||
from langchain.tools import tool
|
||||
@@ -25,13 +23,11 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
async def add(x, y) -> int:
|
||||
"""Adds two numbers."""
|
||||
return x + y
|
||||
|
||||
|
||||
@tool
|
||||
def minus(x, y) -> int:
|
||||
"""Minus two numbers."""
|
||||
return x - y
|
||||
|
||||
|
||||
|
||||
@@ -109,8 +109,7 @@ ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
|
||||
|
||||
|
||||
async def _convert_tool_union_to_tools(
|
||||
tool_union: ToolUnion,
|
||||
ctx: ReadonlyContext,
|
||||
tool_union: ToolUnion, ctx: ReadonlyContext
|
||||
) -> list[BaseTool]:
|
||||
if isinstance(tool_union, BaseTool):
|
||||
return [tool_union]
|
||||
|
||||
@@ -73,7 +73,7 @@ class BaseLlmFlow(ABC):
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Runs the flow using live api."""
|
||||
llm_request = LlmRequest(live_connect_config=types.LiveConnectConfig())
|
||||
llm_request = LlmRequest()
|
||||
event_id = Event.new_id()
|
||||
|
||||
# Preprocess before calling the LLM.
|
||||
@@ -373,9 +373,7 @@ class BaseLlmFlow(ABC):
|
||||
yield event
|
||||
|
||||
async def _preprocess_async(
|
||||
self,
|
||||
invocation_context: InvocationContext,
|
||||
llm_request: LlmRequest,
|
||||
self, invocation_context: InvocationContext, llm_request: LlmRequest
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
|
||||
|
||||
@@ -57,31 +57,30 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
if agent.output_schema and not agent.tools:
|
||||
llm_request.set_output_schema(agent.output_schema)
|
||||
|
||||
if llm_request.live_connect_config:
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
invocation_context.run_config.response_modalities
|
||||
)
|
||||
llm_request.live_connect_config.speech_config = (
|
||||
invocation_context.run_config.speech_config
|
||||
)
|
||||
llm_request.live_connect_config.output_audio_transcription = (
|
||||
invocation_context.run_config.output_audio_transcription
|
||||
)
|
||||
llm_request.live_connect_config.input_audio_transcription = (
|
||||
invocation_context.run_config.input_audio_transcription
|
||||
)
|
||||
llm_request.live_connect_config.realtime_input_config = (
|
||||
invocation_context.run_config.realtime_input_config
|
||||
)
|
||||
llm_request.live_connect_config.enable_affective_dialog = (
|
||||
invocation_context.run_config.enable_affective_dialog
|
||||
)
|
||||
llm_request.live_connect_config.proactivity = (
|
||||
invocation_context.run_config.proactivity
|
||||
)
|
||||
llm_request.live_connect_config.session_resumption = (
|
||||
invocation_context.run_config.session_resumption
|
||||
)
|
||||
llm_request.live_connect_config.response_modalities = (
|
||||
invocation_context.run_config.response_modalities
|
||||
)
|
||||
llm_request.live_connect_config.speech_config = (
|
||||
invocation_context.run_config.speech_config
|
||||
)
|
||||
llm_request.live_connect_config.output_audio_transcription = (
|
||||
invocation_context.run_config.output_audio_transcription
|
||||
)
|
||||
llm_request.live_connect_config.input_audio_transcription = (
|
||||
invocation_context.run_config.input_audio_transcription
|
||||
)
|
||||
llm_request.live_connect_config.realtime_input_config = (
|
||||
invocation_context.run_config.realtime_input_config
|
||||
)
|
||||
llm_request.live_connect_config.enable_affective_dialog = (
|
||||
invocation_context.run_config.enable_affective_dialog
|
||||
)
|
||||
llm_request.live_connect_config.proactivity = (
|
||||
invocation_context.run_config.proactivity
|
||||
)
|
||||
llm_request.live_connect_config.session_resumption = (
|
||||
invocation_context.run_config.session_resumption
|
||||
)
|
||||
|
||||
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
|
||||
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator as ABCAsyncGenerator
|
||||
import inspect
|
||||
from typing import get_origin
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
@@ -25,7 +22,6 @@ from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from ..tools.base_tool import BaseTool
|
||||
from ..tools.function_tool import FunctionTool
|
||||
|
||||
|
||||
def _find_tool_with_function_declarations(
|
||||
@@ -70,13 +66,13 @@ class LlmRequest(BaseModel):
|
||||
config: types.GenerateContentConfig = Field(
|
||||
default_factory=types.GenerateContentConfig
|
||||
)
|
||||
live_connect_config: types.LiveConnectConfig = Field(
|
||||
default_factory=types.LiveConnectConfig
|
||||
)
|
||||
"""Additional config for the generate content request.
|
||||
|
||||
tools in generate_content_config should not be set.
|
||||
"""
|
||||
live_connect_config: Optional[types.LiveConnectConfig] = None
|
||||
"""Live connection config.
|
||||
"""
|
||||
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
|
||||
"""The tools dictionary."""
|
||||
|
||||
@@ -103,23 +99,7 @@ class LlmRequest(BaseModel):
|
||||
return
|
||||
declarations = []
|
||||
for tool in tools:
|
||||
if self.live_connect_config is not None:
|
||||
# ignore response for tools that returns AsyncGenerator that the model
|
||||
# can't understand yet even though the model can't handle it, streaming
|
||||
# tools can handle it.
|
||||
# to check type, use typing.collections.abc.AsyncGenerator and not
|
||||
# typing.AsyncGenerator
|
||||
is_async_generator_return = False
|
||||
if isinstance(tool, FunctionTool):
|
||||
signature = inspect.signature(tool.func)
|
||||
is_async_generator_return = (
|
||||
get_origin(signature.return_annotation) is ABCAsyncGenerator
|
||||
)
|
||||
declaration = tool._get_declaration(
|
||||
ignore_return_declaration=is_async_generator_return
|
||||
)
|
||||
else:
|
||||
declaration = tool._get_declaration()
|
||||
declaration = tool._get_declaration()
|
||||
if declaration:
|
||||
declarations.append(declaration)
|
||||
self.tools_dict[tool.name] = tool
|
||||
|
||||
@@ -195,7 +195,6 @@ def build_function_declaration(
|
||||
func: Union[Callable, BaseModel],
|
||||
ignore_params: Optional[list[str]] = None,
|
||||
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
|
||||
ignore_return_declaration: bool = False,
|
||||
) -> types.FunctionDeclaration:
|
||||
signature = inspect.signature(func)
|
||||
should_update_signature = False
|
||||
@@ -233,11 +232,9 @@ def build_function_declaration(
|
||||
new_func.__annotations__ = func.__annotations__
|
||||
|
||||
return (
|
||||
from_function_with_options(func, variant, ignore_return_declaration)
|
||||
from_function_with_options(func, variant)
|
||||
if not should_update_signature
|
||||
else from_function_with_options(
|
||||
new_func, variant, ignore_return_declaration
|
||||
)
|
||||
else from_function_with_options(new_func, variant)
|
||||
)
|
||||
|
||||
|
||||
@@ -296,7 +293,6 @@ def build_function_declaration_util(
|
||||
def from_function_with_options(
|
||||
func: Callable,
|
||||
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
|
||||
ignore_return_declaration: bool = False,
|
||||
) -> 'types.FunctionDeclaration':
|
||||
|
||||
parameters_properties = {}
|
||||
@@ -328,8 +324,7 @@ def from_function_with_options(
|
||||
declaration.parameters
|
||||
)
|
||||
)
|
||||
|
||||
if variant == GoogleLLMVariant.GEMINI_API or ignore_return_declaration:
|
||||
if variant == GoogleLLMVariant.GEMINI_API:
|
||||
return declaration
|
||||
|
||||
return_annotation = inspect.signature(func).return_annotation
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
@@ -62,9 +61,7 @@ class AgentTool(BaseTool):
|
||||
return data
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration:
|
||||
from ..agents.llm_agent import LlmAgent
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
@@ -115,9 +115,7 @@ class IntegrationConnectorTool(BaseTool):
|
||||
self._auth_credential = auth_credential
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
|
||||
for field in self.EXCLUDE_FIELDS:
|
||||
@@ -128,7 +126,7 @@ class IntegrationConnectorTool(BaseTool):
|
||||
schema_dict['required'].remove(field)
|
||||
|
||||
parameters = _to_gemini_schema(schema_dict)
|
||||
function_decl = types.FunctionDeclaration(
|
||||
function_decl = FunctionDeclaration(
|
||||
name=self.name, description=self.description, parameters=parameters
|
||||
)
|
||||
return function_decl
|
||||
|
||||
@@ -78,9 +78,7 @@ class BaseTool(ABC):
|
||||
self.is_long_running = is_long_running
|
||||
self.custom_metadata = custom_metadata
|
||||
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
"""Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.
|
||||
|
||||
NOTE:
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -64,9 +62,7 @@ class CrewaiTool(FunctionTool):
|
||||
self.description = tool.description
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration:
|
||||
"""Build the function declaration for the tool."""
|
||||
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
|
||||
False,
|
||||
|
||||
@@ -62,9 +62,7 @@ class FunctionTool(BaseTool):
|
||||
self._ignore_params = ['tool_context', 'input_stream']
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
function_decl = types.FunctionDeclaration.model_validate(
|
||||
build_function_declaration(
|
||||
func=self.func,
|
||||
@@ -72,7 +70,6 @@ class FunctionTool(BaseTool):
|
||||
# input_stream is for streaming tool
|
||||
ignore_params=self._ignore_params,
|
||||
variant=self._api_variant,
|
||||
ignore_return_declaration=ignore_return_declaration,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -11,13 +11,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
@@ -51,9 +52,7 @@ class GoogleApiTool(BaseTool):
|
||||
self.configure_auth(client_id, client_secret)
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
return self._rest_api_tool._get_declaration()
|
||||
|
||||
@override
|
||||
|
||||
@@ -101,9 +101,7 @@ class LangchainTool(FunctionTool):
|
||||
# else: keep default from FunctionTool
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration:
|
||||
"""Build the function declaration for the tool.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
@@ -38,10 +37,7 @@ class LoadArtifactsTool(BaseTool):
|
||||
description='Loads the artifacts and adds them to the session.',
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
@@ -59,9 +58,7 @@ class LoadMemoryTool(FunctionTool):
|
||||
super().__init__(load_memory)
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
|
||||
@@ -45,9 +45,7 @@ class LongRunningFunctionTool(FunctionTool):
|
||||
self.is_long_running = True
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
declaration = super()._get_declaration()
|
||||
if declaration:
|
||||
instruction = (
|
||||
|
||||
@@ -19,7 +19,6 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.genai import types
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -98,9 +97,7 @@ class MCPTool(BaseAuthenticatedTool):
|
||||
self._mcp_session_manager = mcp_session_manager
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Gets the function declaration for the tool.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -23,7 +23,6 @@ from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from fastapi.openapi.models import Operation
|
||||
from google.genai import types
|
||||
from google.genai.types import FunctionDeclaration
|
||||
import requests
|
||||
from typing_extensions import override
|
||||
@@ -182,9 +181,7 @@ class RestApiTool(BaseTool):
|
||||
return RestApiTool.from_parsed_operation(operation)
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> FunctionDeclaration:
|
||||
"""Returns the function declaration in the Gemini Schema format."""
|
||||
schema_dict = self._operation_parser.get_json_schema()
|
||||
parameters = _to_gemini_schema(schema_dict)
|
||||
|
||||
@@ -11,9 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
@@ -24,9 +21,7 @@ from ..base_tool import BaseTool
|
||||
class BaseRetrievalTool(BaseTool):
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> types.FunctionDeclaration:
|
||||
return types.FunctionDeclaration(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
|
||||
@@ -81,9 +81,7 @@ class SetModelResponseTool(BaseTool):
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_declaration(
|
||||
self, ignore_return_declaration: bool = False
|
||||
) -> Optional[types.FunctionDeclaration]:
|
||||
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
||||
"""Gets the OpenAPI specification of this tool."""
|
||||
function_decl = types.FunctionDeclaration.model_validate(
|
||||
build_function_declaration(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user