ADK changes

PiperOrigin-RevId: 794403729
This commit is contained in:
Google Team Member
2025-08-12 22:28:43 -07:00
committed by Copybara-Service
parent e2518dc371
commit 114db93d70
22 changed files with 54 additions and 134 deletions
@@ -15,8 +15,6 @@
"""
This agent aims to test the Langchain tool with Langchain's StructuredTool
"""
from __future__ import annotations
from google.adk.agents.llm_agent import Agent
from google.adk.tools.langchain_tool import LangchainTool
from langchain.tools import tool
@@ -25,13 +23,11 @@ from pydantic import BaseModel
async def add(x, y) -> int:
"""Adds two numbers."""
return x + y
@tool
def minus(x, y) -> int:
"""Minus two numbers."""
return x - y
+1 -2
View File
@@ -109,8 +109,7 @@ ToolUnion: TypeAlias = Union[Callable, BaseTool, BaseToolset]
async def _convert_tool_union_to_tools(
tool_union: ToolUnion,
ctx: ReadonlyContext,
tool_union: ToolUnion, ctx: ReadonlyContext
) -> list[BaseTool]:
if isinstance(tool_union, BaseTool):
return [tool_union]
@@ -73,7 +73,7 @@ class BaseLlmFlow(ABC):
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Runs the flow using live api."""
llm_request = LlmRequest(live_connect_config=types.LiveConnectConfig())
llm_request = LlmRequest()
event_id = Event.new_id()
# Preprocess before calling the LLM.
@@ -373,9 +373,7 @@ class BaseLlmFlow(ABC):
yield event
async def _preprocess_async(
self,
invocation_context: InvocationContext,
llm_request: LlmRequest,
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
from ...agents.llm_agent import LlmAgent
+24 -25
View File
@@ -57,31 +57,30 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
if agent.output_schema and not agent.tools:
llm_request.set_output_schema(agent.output_schema)
if llm_request.live_connect_config:
llm_request.live_connect_config.response_modalities = (
invocation_context.run_config.response_modalities
)
llm_request.live_connect_config.speech_config = (
invocation_context.run_config.speech_config
)
llm_request.live_connect_config.output_audio_transcription = (
invocation_context.run_config.output_audio_transcription
)
llm_request.live_connect_config.input_audio_transcription = (
invocation_context.run_config.input_audio_transcription
)
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
llm_request.live_connect_config.enable_affective_dialog = (
invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
invocation_context.run_config.proactivity
)
llm_request.live_connect_config.session_resumption = (
invocation_context.run_config.session_resumption
)
llm_request.live_connect_config.response_modalities = (
invocation_context.run_config.response_modalities
)
llm_request.live_connect_config.speech_config = (
invocation_context.run_config.speech_config
)
llm_request.live_connect_config.output_audio_transcription = (
invocation_context.run_config.output_audio_transcription
)
llm_request.live_connect_config.input_audio_transcription = (
invocation_context.run_config.input_audio_transcription
)
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
llm_request.live_connect_config.enable_affective_dialog = (
invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
invocation_context.run_config.proactivity
)
llm_request.live_connect_config.session_resumption = (
invocation_context.run_config.session_resumption
)
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
+4 -24
View File
@@ -14,9 +14,6 @@
from __future__ import annotations
from collections.abc import AsyncGenerator as ABCAsyncGenerator
import inspect
from typing import get_origin
from typing import Optional
from google.genai import types
@@ -25,7 +22,6 @@ from pydantic import ConfigDict
from pydantic import Field
from ..tools.base_tool import BaseTool
from ..tools.function_tool import FunctionTool
def _find_tool_with_function_declarations(
@@ -70,13 +66,13 @@ class LlmRequest(BaseModel):
config: types.GenerateContentConfig = Field(
default_factory=types.GenerateContentConfig
)
live_connect_config: types.LiveConnectConfig = Field(
default_factory=types.LiveConnectConfig
)
"""Additional config for the generate content request.
tools in generate_content_config should not be set.
"""
live_connect_config: Optional[types.LiveConnectConfig] = None
"""Live connection config.
"""
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
"""The tools dictionary."""
@@ -103,23 +99,7 @@ class LlmRequest(BaseModel):
return
declarations = []
for tool in tools:
if self.live_connect_config is not None:
# ignore response for tools that returns AsyncGenerator that the model
# can't understand yet even though the model can't handle it, streaming
# tools can handle it.
# to check type, use typing.collections.abc.AsyncGenerator and not
# typing.AsyncGenerator
is_async_generator_return = False
if isinstance(tool, FunctionTool):
signature = inspect.signature(tool.func)
is_async_generator_return = (
get_origin(signature.return_annotation) is ABCAsyncGenerator
)
declaration = tool._get_declaration(
ignore_return_declaration=is_async_generator_return
)
else:
declaration = tool._get_declaration()
declaration = tool._get_declaration()
if declaration:
declarations.append(declaration)
self.tools_dict[tool.name] = tool
@@ -195,7 +195,6 @@ def build_function_declaration(
func: Union[Callable, BaseModel],
ignore_params: Optional[list[str]] = None,
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
ignore_return_declaration: bool = False,
) -> types.FunctionDeclaration:
signature = inspect.signature(func)
should_update_signature = False
@@ -233,11 +232,9 @@ def build_function_declaration(
new_func.__annotations__ = func.__annotations__
return (
from_function_with_options(func, variant, ignore_return_declaration)
from_function_with_options(func, variant)
if not should_update_signature
else from_function_with_options(
new_func, variant, ignore_return_declaration
)
else from_function_with_options(new_func, variant)
)
@@ -296,7 +293,6 @@ def build_function_declaration_util(
def from_function_with_options(
func: Callable,
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
ignore_return_declaration: bool = False,
) -> 'types.FunctionDeclaration':
parameters_properties = {}
@@ -328,8 +324,7 @@ def from_function_with_options(
declaration.parameters
)
)
if variant == GoogleLLMVariant.GEMINI_API or ignore_return_declaration:
if variant == GoogleLLMVariant.GEMINI_API:
return declaration
return_annotation = inspect.signature(func).return_annotation
+1 -4
View File
@@ -15,7 +15,6 @@
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
@@ -62,9 +61,7 @@ class AgentTool(BaseTool):
return data
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration:
from ..agents.llm_agent import LlmAgent
from ..utils.variant_utils import GoogleLLMVariant
@@ -20,7 +20,7 @@ from typing import Dict
from typing import Optional
from typing import Union
from google.genai import types
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from ...auth.auth_credential import AuthCredential
@@ -115,9 +115,7 @@ class IntegrationConnectorTool(BaseTool):
self._auth_credential = auth_credential
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
for field in self.EXCLUDE_FIELDS:
@@ -128,7 +126,7 @@ class IntegrationConnectorTool(BaseTool):
schema_dict['required'].remove(field)
parameters = _to_gemini_schema(schema_dict)
function_decl = types.FunctionDeclaration(
function_decl = FunctionDeclaration(
name=self.name, description=self.description, parameters=parameters
)
return function_decl
+1 -3
View File
@@ -78,9 +78,7 @@ class BaseTool(ABC):
self.is_long_running = is_long_running
self.custom_metadata = custom_metadata
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
"""Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.
NOTE:
+1 -5
View File
@@ -14,8 +14,6 @@
from __future__ import annotations
from typing import Optional
from google.genai import types
from typing_extensions import override
@@ -64,9 +62,7 @@ class CrewaiTool(FunctionTool):
self.description = tool.description
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool."""
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
False,
+1 -4
View File
@@ -62,9 +62,7 @@ class FunctionTool(BaseTool):
self._ignore_params = ['tool_context', 'input_stream']
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
function_decl = types.FunctionDeclaration.model_validate(
build_function_declaration(
func=self.func,
@@ -72,7 +70,6 @@ class FunctionTool(BaseTool):
# input_stream is for streaming tool
ignore_params=self._ignore_params,
variant=self._api_variant,
ignore_return_declaration=ignore_return_declaration,
)
)
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import Optional
from google.genai import types
from google.genai.types import FunctionDeclaration
from typing_extensions import override
from ...auth.auth_credential import AuthCredential
@@ -51,9 +52,7 @@ class GoogleApiTool(BaseTool):
self.configure_auth(client_id, client_secret)
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> FunctionDeclaration:
return self._rest_api_tool._get_declaration()
@override
+1 -3
View File
@@ -101,9 +101,7 @@ class LangchainTool(FunctionTool):
# else: keep default from FunctionTool
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool.
Returns:
+1 -5
View File
@@ -16,7 +16,6 @@ from __future__ import annotations
import json
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
@@ -38,10 +37,7 @@ class LoadArtifactsTool(BaseTool):
description='Loads the artifacts and adds them to the session.',
)
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
+1 -4
View File
@@ -14,7 +14,6 @@
from __future__ import annotations
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
@@ -59,9 +58,7 @@ class LoadMemoryTool(FunctionTool):
super().__init__(load_memory)
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration | None:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
+1 -3
View File
@@ -45,9 +45,7 @@ class LongRunningFunctionTool(FunctionTool):
self.is_long_running = True
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
declaration = super()._get_declaration()
if declaration:
instruction = (
+1 -4
View File
@@ -19,7 +19,6 @@ import logging
from typing import Optional
from fastapi.openapi.models import APIKeyIn
from google.genai import types
from google.genai.types import FunctionDeclaration
from typing_extensions import override
@@ -98,9 +97,7 @@ class MCPTool(BaseAuthenticatedTool):
self._mcp_session_manager = mcp_session_manager
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> FunctionDeclaration:
"""Gets the function declaration for the tool.
Returns:
@@ -23,7 +23,6 @@ from typing import Tuple
from typing import Union
from fastapi.openapi.models import Operation
from google.genai import types
from google.genai.types import FunctionDeclaration
import requests
from typing_extensions import override
@@ -182,9 +181,7 @@ class RestApiTool(BaseTool):
return RestApiTool.from_parsed_operation(operation)
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> FunctionDeclaration:
"""Returns the function declaration in the Gemini Schema format."""
schema_dict = self._operation_parser.get_json_schema()
parameters = _to_gemini_schema(schema_dict)
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from google.genai import types
from typing_extensions import override
@@ -24,9 +21,7 @@ from ..base_tool import BaseTool
class BaseRetrievalTool(BaseTool):
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> types.FunctionDeclaration:
return types.FunctionDeclaration(
name=self.name,
description=self.description,
@@ -81,9 +81,7 @@ class SetModelResponseTool(BaseTool):
)
@override
def _get_declaration(
self, ignore_return_declaration: bool = False
) -> Optional[types.FunctionDeclaration]:
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
"""Gets the OpenAPI specification of this tool."""
function_decl = types.FunctionDeclaration.model_validate(
build_function_declaration(

Some files were not shown because too many files have changed in this diff Show More