You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
refactor: extract reusable functions from hitl and auth preprocessor
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 877578253
This commit is contained in:
committed by
Copybara-Service
parent
8e79a12d6b
commit
c59afc21cb
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from typing_extensions import override
|
||||
@@ -25,6 +26,7 @@ from ..flows.llm_flows import functions
|
||||
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
from ..models.llm_request import LlmRequest
|
||||
from ..sessions.state import State
|
||||
from .auth_handler import AuthHandler
|
||||
from .auth_tool import AuthConfig
|
||||
from .auth_tool import AuthToolArguments
|
||||
@@ -35,6 +37,93 @@ from .auth_tool import AuthToolArguments
|
||||
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
|
||||
|
||||
|
||||
async def _store_auth_and_collect_resume_targets(
|
||||
events: list[Event],
|
||||
auth_fc_ids: set[str],
|
||||
auth_responses: dict[str, Any],
|
||||
state: State,
|
||||
) -> set[str]:
|
||||
"""Store auth credentials and return original function call IDs to resume.
|
||||
|
||||
Scans session events for ``adk_request_credential`` function calls whose
|
||||
IDs are in *auth_fc_ids*, extracts ``credential_key`` from their
|
||||
``AuthToolArguments`` args, merges ``credential_key`` into the
|
||||
corresponding auth response, stores credentials via ``AuthHandler``,
|
||||
and returns the set of original function call IDs that should be
|
||||
re-executed (excluding toolset auth).
|
||||
|
||||
Args:
|
||||
events: Session events to scan.
|
||||
auth_fc_ids: IDs of ``adk_request_credential`` function calls to match.
|
||||
auth_responses: Mapping of FC ID -> auth config response dict from the
|
||||
client.
|
||||
state: Session state for temporary credential storage.
|
||||
|
||||
Returns:
|
||||
Set of original function call IDs to resume.
|
||||
"""
|
||||
# Step 1: Scan events for matching adk_request_credential function calls
|
||||
# to extract AuthToolArguments (contains credential_key).
|
||||
requested_auth_config_by_id: dict[str, AuthConfig] = {}
|
||||
for event in events:
|
||||
event_function_calls = event.get_function_calls()
|
||||
if not event_function_calls:
|
||||
continue
|
||||
try:
|
||||
for function_call in event_function_calls:
|
||||
if (
|
||||
function_call.id in auth_fc_ids
|
||||
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
requested_auth_config_by_id[function_call.id] = args.auth_config
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
# Step 2: Store credentials. Merge credential_key from the original
|
||||
# request into the client's auth response before storing.
|
||||
for fc_id in auth_fc_ids:
|
||||
if fc_id not in auth_responses:
|
||||
continue
|
||||
auth_config = AuthConfig.model_validate(auth_responses[fc_id])
|
||||
requested_auth_config = requested_auth_config_by_id.get(fc_id)
|
||||
if (
|
||||
requested_auth_config
|
||||
and requested_auth_config.credential_key is not None
|
||||
):
|
||||
auth_config.credential_key = requested_auth_config.credential_key
|
||||
await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
|
||||
state=state
|
||||
)
|
||||
|
||||
# Step 3: Collect original function call IDs to resume, skipping
|
||||
# toolset auth entries which don't map to a resumable function call.
|
||||
tools_to_resume: set[str] = set()
|
||||
for fc_id in auth_fc_ids:
|
||||
requested_auth_config = requested_auth_config_by_id.get(fc_id)
|
||||
if not requested_auth_config:
|
||||
continue
|
||||
# Re-parse to get function_call_id (AuthConfig doesn't carry it;
|
||||
# AuthToolArguments does).
|
||||
for event in events:
|
||||
event_function_calls = event.get_function_calls()
|
||||
if not event_function_calls:
|
||||
continue
|
||||
for function_call in event_function_calls:
|
||||
if (
|
||||
function_call.id == fc_id
|
||||
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
if args.function_call_id.startswith(
|
||||
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX
|
||||
):
|
||||
continue
|
||||
tools_to_resume.add(args.function_call_id)
|
||||
|
||||
return tools_to_resume
|
||||
|
||||
|
||||
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles auth information to build the LLM request."""
|
||||
|
||||
@@ -49,8 +138,8 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
if not events:
|
||||
return
|
||||
|
||||
request_euc_function_call_ids = set()
|
||||
# find the last event with non-None content
|
||||
# Find the last user-authored event with function responses to
|
||||
# identify adk_request_credential responses.
|
||||
last_event_with_content = None
|
||||
for i in range(len(events) - 1, -1, -1):
|
||||
event = events[i]
|
||||
@@ -58,7 +147,6 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
last_event_with_content = event
|
||||
break
|
||||
|
||||
# check if the last event with content is authored by user
|
||||
if not last_event_with_content or last_event_with_content.author != 'user':
|
||||
return
|
||||
|
||||
@@ -66,104 +154,55 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
if not responses:
|
||||
return
|
||||
|
||||
requested_auth_config_by_request_id = {}
|
||||
# look for auth response
|
||||
# Collect adk_request_credential function response IDs and their
|
||||
# response dicts.
|
||||
auth_fc_ids: set[str] = set()
|
||||
auth_responses: dict[str, Any] = {}
|
||||
for function_call_response in responses:
|
||||
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
# found the function call response for the system long running request euc
|
||||
# function call
|
||||
request_euc_function_call_ids.add(function_call_response.id)
|
||||
|
||||
if request_euc_function_call_ids:
|
||||
for event in events:
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
try:
|
||||
for function_call in function_calls:
|
||||
if (
|
||||
function_call.id in request_euc_function_call_ids
|
||||
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
):
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
requested_auth_config_by_request_id[function_call.id] = (
|
||||
args.auth_config
|
||||
)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
for function_call_response in responses:
|
||||
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
|
||||
auth_config = AuthConfig.model_validate(function_call_response.response)
|
||||
requested_auth_config = requested_auth_config_by_request_id.get(
|
||||
function_call_response.id
|
||||
)
|
||||
if (
|
||||
requested_auth_config
|
||||
and requested_auth_config.credential_key is not None
|
||||
):
|
||||
auth_config.credential_key = requested_auth_config.credential_key
|
||||
await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
|
||||
state=invocation_context.session.state
|
||||
auth_fc_ids.add(function_call_response.id)
|
||||
auth_responses[function_call_response.id] = (
|
||||
function_call_response.response
|
||||
)
|
||||
|
||||
if not request_euc_function_call_ids:
|
||||
if not auth_fc_ids:
|
||||
return
|
||||
|
||||
# Store credentials and collect tools to resume.
|
||||
tools_to_resume = await _store_auth_and_collect_resume_targets(
|
||||
events, auth_fc_ids, auth_responses, invocation_context.session.state
|
||||
)
|
||||
|
||||
if not tools_to_resume:
|
||||
return
|
||||
|
||||
# Find the original function call event and re-execute the tools
|
||||
# that needed auth.
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# looking for the system long running request euc function call
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
tools_to_resume = set()
|
||||
|
||||
for function_call in function_calls:
|
||||
if function_call.id not in request_euc_function_call_ids:
|
||||
continue
|
||||
args = AuthToolArguments.model_validate(function_call.args)
|
||||
|
||||
# Skip toolset auth - auth response is already stored in session state
|
||||
# and we don't need to resume a function call for toolsets
|
||||
if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
|
||||
continue
|
||||
|
||||
tools_to_resume.add(args.function_call_id)
|
||||
if not tools_to_resume:
|
||||
continue
|
||||
|
||||
# found the system long running request euc function call
|
||||
# looking for original function call that requests euc
|
||||
for j in range(i - 1, -1, -1):
|
||||
event = events[j]
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
if any([
|
||||
function_call.id in tools_to_resume
|
||||
for function_call in function_calls
|
||||
]):
|
||||
if function_response_event := await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# there could be parallel function calls that require auth
|
||||
# auth response would be a dict keyed by function call id
|
||||
tools_to_resume,
|
||||
):
|
||||
yield function_response_event
|
||||
return
|
||||
return
|
||||
if any([
|
||||
function_call.id in tools_to_resume
|
||||
for function_call in function_calls
|
||||
]):
|
||||
if function_response_event := await functions.handle_function_calls_async(
|
||||
invocation_context,
|
||||
event,
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
tools_to_resume,
|
||||
):
|
||||
yield function_response_event
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
request_processor = _AuthLlmRequestProcessor()
|
||||
|
||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -37,6 +38,65 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation:
|
||||
"""Parse ToolConfirmation from a function response dict.
|
||||
|
||||
Handles both the direct dict format and the ADK client's
|
||||
``{'response': json_string}`` wrapper format.
|
||||
|
||||
"""
|
||||
if response and len(response.values()) == 1 and 'response' in response.keys():
|
||||
return ToolConfirmation.model_validate(json.loads(response['response']))
|
||||
return ToolConfirmation.model_validate(response)
|
||||
|
||||
|
||||
def _resolve_confirmation_targets(
|
||||
events: list[Event],
|
||||
confirmation_fc_ids: set[str],
|
||||
confirmations_by_fc_id: dict[str, ToolConfirmation],
|
||||
) -> tuple[dict[str, ToolConfirmation], dict[str, types.FunctionCall]]:
|
||||
"""Find original function calls for confirmed tools.
|
||||
|
||||
Scans events for ``adk_request_confirmation`` function calls whose IDs
|
||||
are in *confirmation_fc_ids*, extracts the ``originalFunctionCall`` from
|
||||
their args, and maps each confirmation to the original FC ID.
|
||||
|
||||
Args:
|
||||
events: Session events to scan.
|
||||
confirmation_fc_ids: IDs of ``adk_request_confirmation`` function calls.
|
||||
confirmations_by_fc_id: Mapping of confirmation FC ID ->
|
||||
``ToolConfirmation``.
|
||||
|
||||
Returns:
|
||||
Tuple of ``(tool_confirmation_dict, original_fcs_dict)`` where both
|
||||
are keyed by the ORIGINAL function call IDs.
|
||||
"""
|
||||
tool_confirmation_dict: dict[str, ToolConfirmation] = {}
|
||||
original_fcs_dict: dict[str, types.FunctionCall] = {}
|
||||
|
||||
for event in events:
|
||||
event_function_calls = event.get_function_calls()
|
||||
if not event_function_calls:
|
||||
continue
|
||||
|
||||
for function_call in event_function_calls:
|
||||
if function_call.id not in confirmation_fc_ids:
|
||||
continue
|
||||
|
||||
args = function_call.args
|
||||
if 'originalFunctionCall' not in args:
|
||||
continue
|
||||
original_function_call = types.FunctionCall(
|
||||
**args['originalFunctionCall']
|
||||
)
|
||||
tool_confirmation_dict[original_function_call.id] = (
|
||||
confirmations_by_fc_id[function_call.id]
|
||||
)
|
||||
original_fcs_dict[original_function_call.id] = original_function_call
|
||||
|
||||
return tool_confirmation_dict, original_fcs_dict
|
||||
|
||||
|
||||
class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
"""Handles tool confirmation information to build the LLM request."""
|
||||
|
||||
@@ -53,14 +113,12 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
if not events:
|
||||
return
|
||||
|
||||
request_confirmation_function_responses = (
|
||||
dict()
|
||||
) # {function call id, tool confirmation}
|
||||
|
||||
# Step 1: Find the last user-authored event and parse confirmation
|
||||
# responses from it.
|
||||
confirmations_by_fc_id: dict[str, ToolConfirmation] = {}
|
||||
confirmation_event_index = -1
|
||||
for k in range(len(events) - 1, -1, -1):
|
||||
event = events[k]
|
||||
# Find the first event authored by user
|
||||
if not event.author or event.author != 'user':
|
||||
continue
|
||||
responses = event.get_function_responses()
|
||||
@@ -70,101 +128,58 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
|
||||
for function_response in responses:
|
||||
if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
|
||||
continue
|
||||
|
||||
# Find the FunctionResponse event that contains the user provided tool
|
||||
# confirmation
|
||||
if (
|
||||
confirmations_by_fc_id[function_response.id] = _parse_tool_confirmation(
|
||||
function_response.response
|
||||
and len(function_response.response.values()) == 1
|
||||
and 'response' in function_response.response.keys()
|
||||
):
|
||||
# ADK client must send a resuming run request with a function response
|
||||
# that always encapsulate the confirmation result with a 'response'
|
||||
# key
|
||||
tool_confirmation = ToolConfirmation.model_validate(
|
||||
json.loads(function_response.response['response'])
|
||||
)
|
||||
else:
|
||||
tool_confirmation = ToolConfirmation.model_validate(
|
||||
function_response.response
|
||||
)
|
||||
request_confirmation_function_responses[function_response.id] = (
|
||||
tool_confirmation
|
||||
)
|
||||
confirmation_event_index = k
|
||||
break
|
||||
|
||||
if not request_confirmation_function_responses:
|
||||
if not confirmations_by_fc_id:
|
||||
return
|
||||
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
# Step 2: Resolve confirmation targets using extracted helper.
|
||||
confirmation_fc_ids = set(confirmations_by_fc_id.keys())
|
||||
tools_to_resume_with_confirmation, tools_to_resume_with_args = (
|
||||
_resolve_confirmation_targets(
|
||||
events, confirmation_fc_ids, confirmations_by_fc_id
|
||||
)
|
||||
)
|
||||
|
||||
if not tools_to_resume_with_confirmation:
|
||||
return
|
||||
|
||||
# Step 3: Remove tools that have already been confirmed (dedup).
|
||||
for i in range(len(events) - 1, confirmation_event_index, -1):
|
||||
event = events[i]
|
||||
# Find the system generated FunctionCall event requesting the tool
|
||||
# confirmation
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
fr_list = event.get_function_responses()
|
||||
if not fr_list:
|
||||
continue
|
||||
|
||||
tools_to_resume_with_confirmation = (
|
||||
dict()
|
||||
) # {Function call id, tool confirmation}
|
||||
tools_to_resume_with_args = dict() # {Function call id, function calls}
|
||||
|
||||
for function_call in function_calls:
|
||||
if (
|
||||
function_call.id
|
||||
not in request_confirmation_function_responses.keys()
|
||||
):
|
||||
continue
|
||||
|
||||
args = function_call.args
|
||||
if 'originalFunctionCall' not in args:
|
||||
continue
|
||||
original_function_call = types.FunctionCall(
|
||||
**args['originalFunctionCall']
|
||||
)
|
||||
tools_to_resume_with_confirmation[original_function_call.id] = (
|
||||
request_confirmation_function_responses[function_call.id]
|
||||
)
|
||||
tools_to_resume_with_args[original_function_call.id] = (
|
||||
original_function_call
|
||||
)
|
||||
for function_response in fr_list:
|
||||
if function_response.id in tools_to_resume_with_confirmation:
|
||||
tools_to_resume_with_confirmation.pop(function_response.id)
|
||||
tools_to_resume_with_args.pop(function_response.id)
|
||||
if not tools_to_resume_with_confirmation:
|
||||
continue
|
||||
break
|
||||
|
||||
# Remove the tools that have already been confirmed.
|
||||
for i in range(len(events) - 1, confirmation_event_index, -1):
|
||||
event = events[i]
|
||||
function_response = event.get_function_responses()
|
||||
if not function_response:
|
||||
continue
|
||||
|
||||
for function_response in event.get_function_responses():
|
||||
if function_response.id in tools_to_resume_with_confirmation:
|
||||
tools_to_resume_with_confirmation.pop(function_response.id)
|
||||
tools_to_resume_with_args.pop(function_response.id)
|
||||
if not tools_to_resume_with_confirmation:
|
||||
break
|
||||
|
||||
if not tools_to_resume_with_confirmation:
|
||||
continue
|
||||
|
||||
if function_response_event := await functions.handle_function_call_list_async(
|
||||
invocation_context,
|
||||
tools_to_resume_with_args.values(),
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
# There could be parallel function calls that require input
|
||||
# response would be a dict keyed by function call id
|
||||
tools_to_resume_with_confirmation.keys(),
|
||||
tools_to_resume_with_confirmation,
|
||||
):
|
||||
yield function_response_event
|
||||
if not tools_to_resume_with_confirmation:
|
||||
return
|
||||
|
||||
# Step 4: Re-execute the confirmed tools.
|
||||
if function_response_event := await functions.handle_function_call_list_async(
|
||||
invocation_context,
|
||||
tools_to_resume_with_args.values(),
|
||||
{
|
||||
tool.name: tool
|
||||
for tool in await agent.canonical_tools(
|
||||
ReadonlyContext(invocation_context)
|
||||
)
|
||||
},
|
||||
tools_to_resume_with_confirmation.keys(),
|
||||
tools_to_resume_with_confirmation,
|
||||
):
|
||||
yield function_response_event
|
||||
return
|
||||
|
||||
|
||||
request_processor = _RequestConfirmationLlmRequestProcessor()
|
||||
|
||||
@@ -79,7 +79,9 @@ class TestAuthLlmRequestProcessor:
|
||||
@pytest.fixture
|
||||
def mock_auth_config(self):
|
||||
"""Create a mock AuthConfig."""
|
||||
return Mock(spec=AuthConfig)
|
||||
config = Mock(spec=AuthConfig)
|
||||
config.credential_key = None
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def mock_function_response_with_auth(self, mock_auth_config):
|
||||
@@ -347,10 +349,12 @@ class TestAuthLlmRequestProcessor:
|
||||
auth_response_1,
|
||||
auth_response_2,
|
||||
]
|
||||
user_event_with_multiple_responses.get_function_calls.return_value = []
|
||||
|
||||
# Create system function call events
|
||||
system_function_call_1 = Mock()
|
||||
system_function_call_1.id = 'auth_id_1'
|
||||
system_function_call_1.name = REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
system_function_call_1.args = {
|
||||
'function_call_id': 'tool_id_1',
|
||||
'auth_config': mock_auth_config,
|
||||
@@ -358,6 +362,7 @@ class TestAuthLlmRequestProcessor:
|
||||
|
||||
system_function_call_2 = Mock()
|
||||
system_function_call_2.id = 'auth_id_2'
|
||||
system_function_call_2.name = REQUEST_EUC_FUNCTION_CALL_NAME
|
||||
system_function_call_2.args = {
|
||||
'function_call_id': 'tool_id_2',
|
||||
'auth_config': mock_auth_config,
|
||||
|
||||
Reference in New Issue
Block a user