refactor: extract reusable functions from hitl and auth preprocessor

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 877578253
This commit is contained in:
Xiang (Sean) Zhou
2026-03-02 14:27:07 -08:00
committed by Copybara-Service
parent 8e79a12d6b
commit c59afc21cb
3 changed files with 235 additions and 176 deletions
+127 -88
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
from typing import Any
from typing import AsyncGenerator
from typing_extensions import override
@@ -25,6 +26,7 @@ from ..flows.llm_flows import functions
from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor
from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
from ..models.llm_request import LlmRequest
from ..sessions.state import State
from .auth_handler import AuthHandler
from .auth_tool import AuthConfig
from .auth_tool import AuthToolArguments
@@ -35,6 +37,93 @@ from .auth_tool import AuthToolArguments
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
async def _store_auth_and_collect_resume_targets(
    events: list[Event],
    auth_fc_ids: set[str],
    auth_responses: dict[str, Any],
    state: State,
) -> set[str]:
  """Store auth credentials and return original function call IDs to resume.

  Scans session events once for ``adk_request_credential`` function calls
  whose IDs are in *auth_fc_ids*, parses their args as ``AuthToolArguments``,
  merges the requested ``credential_key`` into the corresponding auth
  response, stores credentials via ``AuthHandler``, and returns the set of
  original function call IDs that should be re-executed (excluding toolset
  auth).

  Args:
    events: Session events to scan.
    auth_fc_ids: IDs of ``adk_request_credential`` function calls to match.
    auth_responses: Mapping of FC ID -> auth config response dict from the
      client.
    state: Session state passed to ``AuthHandler`` for credential storage.

  Returns:
    Set of original function call IDs to resume.
  """
  # Step 1: Scan events for matching adk_request_credential function calls.
  # Each matching call is validated exactly once; both the requested
  # AuthConfig (which carries credential_key) and the original
  # function_call_id are recorded here so no second scan of the events is
  # needed later.
  requested_auth_config_by_id: dict[str, AuthConfig] = {}
  original_fc_ids_by_id: dict[str, list[str]] = {}
  for event in events:
    event_function_calls = event.get_function_calls()
    if not event_function_calls:
      continue
    try:
      for function_call in event_function_calls:
        if (
            function_call.id in auth_fc_ids
            and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
        ):
          args = AuthToolArguments.model_validate(function_call.args)
          requested_auth_config_by_id[function_call.id] = args.auth_config
          # A list (not a single value) so that every matching call for the
          # same ID contributes its target, as a full rescan would.
          original_fc_ids_by_id.setdefault(function_call.id, []).append(
              args.function_call_id
          )
    except TypeError:
      # Best effort: a TypeError abandons the remaining calls of this event
      # and the scan moves on to the next event.
      continue

  # Step 2: Store credentials. Merge credential_key from the original
  # request into the client's auth response before storing.
  for fc_id in auth_fc_ids:
    if fc_id not in auth_responses:
      continue
    auth_config = AuthConfig.model_validate(auth_responses[fc_id])
    requested_auth_config = requested_auth_config_by_id.get(fc_id)
    if (
        requested_auth_config
        and requested_auth_config.credential_key is not None
    ):
      auth_config.credential_key = requested_auth_config.credential_key
    await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
        state=state
    )

  # Step 3: Collect original function call IDs to resume, skipping
  # toolset auth entries which don't map to a resumable function call.
  tools_to_resume: set[str] = set()
  for fc_id in auth_fc_ids:
    if not requested_auth_config_by_id.get(fc_id):
      continue
    for original_fc_id in original_fc_ids_by_id.get(fc_id, ()):
      if original_fc_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
        continue
      tools_to_resume.add(original_fc_id)
  return tools_to_resume
class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
"""Handles auth information to build the LLM request."""
@@ -49,8 +138,8 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
if not events:
return
request_euc_function_call_ids = set()
# find the last event with non-None content
# Find the last user-authored event with function responses to
# identify adk_request_credential responses.
last_event_with_content = None
for i in range(len(events) - 1, -1, -1):
event = events[i]
@@ -58,7 +147,6 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
last_event_with_content = event
break
# check if the last event with content is authored by user
if not last_event_with_content or last_event_with_content.author != 'user':
return
@@ -66,104 +154,55 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor):
if not responses:
return
requested_auth_config_by_request_id = {}
# look for auth response
# Collect adk_request_credential function response IDs and their
# response dicts.
auth_fc_ids: set[str] = set()
auth_responses: dict[str, Any] = {}
for function_call_response in responses:
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
continue
# found the function call response for the system long running request euc
# function call
request_euc_function_call_ids.add(function_call_response.id)
if request_euc_function_call_ids:
for event in events:
function_calls = event.get_function_calls()
if not function_calls:
continue
try:
for function_call in function_calls:
if (
function_call.id in request_euc_function_call_ids
and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME
):
args = AuthToolArguments.model_validate(function_call.args)
requested_auth_config_by_request_id[function_call.id] = (
args.auth_config
)
except TypeError:
continue
for function_call_response in responses:
if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME:
continue
auth_config = AuthConfig.model_validate(function_call_response.response)
requested_auth_config = requested_auth_config_by_request_id.get(
function_call_response.id
)
if (
requested_auth_config
and requested_auth_config.credential_key is not None
):
auth_config.credential_key = requested_auth_config.credential_key
await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
state=invocation_context.session.state
auth_fc_ids.add(function_call_response.id)
auth_responses[function_call_response.id] = (
function_call_response.response
)
if not request_euc_function_call_ids:
if not auth_fc_ids:
return
# Store credentials and collect tools to resume.
tools_to_resume = await _store_auth_and_collect_resume_targets(
events, auth_fc_ids, auth_responses, invocation_context.session.state
)
if not tools_to_resume:
return
# Find the original function call event and re-execute the tools
# that needed auth.
for i in range(len(events) - 2, -1, -1):
event = events[i]
# looking for the system long running request euc function call
function_calls = event.get_function_calls()
if not function_calls:
continue
tools_to_resume = set()
for function_call in function_calls:
if function_call.id not in request_euc_function_call_ids:
continue
args = AuthToolArguments.model_validate(function_call.args)
# Skip toolset auth - auth response is already stored in session state
# and we don't need to resume a function call for toolsets
if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX):
continue
tools_to_resume.add(args.function_call_id)
if not tools_to_resume:
continue
# found the system long running request euc function call
# looking for original function call that requests euc
for j in range(i - 1, -1, -1):
event = events[j]
function_calls = event.get_function_calls()
if not function_calls:
continue
if any([
function_call.id in tools_to_resume
for function_call in function_calls
]):
if function_response_event := await functions.handle_function_calls_async(
invocation_context,
event,
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# there could be parallel function calls that require auth
# auth response would be a dict keyed by function call id
tools_to_resume,
):
yield function_response_event
return
return
if any([
function_call.id in tools_to_resume
for function_call in function_calls
]):
if function_response_event := await functions.handle_function_calls_async(
invocation_context,
event,
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
tools_to_resume,
):
yield function_response_event
return
return
request_processor = _AuthLlmRequestProcessor()
@@ -15,6 +15,7 @@ from __future__ import annotations
import json
import logging
from typing import Any
from typing import AsyncGenerator
from typing import TYPE_CHECKING
@@ -37,6 +38,65 @@ if TYPE_CHECKING:
logger = logging.getLogger('google_adk.' + __name__)
def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation:
  """Build a ``ToolConfirmation`` from a function response payload.

  Two payload shapes are accepted:

  * a dict whose only key is ``'response'`` -- the value is decoded with
    ``json.loads`` and the decoded object is validated;
  * anything else -- the payload itself is validated directly.
  """
  is_wrapped = bool(response) and len(response) == 1 and 'response' in response
  payload = json.loads(response['response']) if is_wrapped else response
  return ToolConfirmation.model_validate(payload)
def _resolve_confirmation_targets(
    events: list[Event],
    confirmation_fc_ids: set[str],
    confirmations_by_fc_id: dict[str, ToolConfirmation],
) -> tuple[dict[str, ToolConfirmation], dict[str, types.FunctionCall]]:
  """Map user confirmations back to the function calls they refer to.

  Walks every function call in *events*. A call is considered when its ID
  is in *confirmation_fc_ids* and its args contain an
  ``originalFunctionCall`` entry; that entry is rebuilt into a
  ``types.FunctionCall`` and used as the key for both result mappings.

  Args:
    events: Session events to scan.
    confirmation_fc_ids: IDs of the confirmation-request function calls.
    confirmations_by_fc_id: Mapping of confirmation-request FC ID ->
      ``ToolConfirmation``; looked up by the ID of each matching call.

  Returns:
    ``(confirmations, original_calls)``, both keyed by the ID of the
    ORIGINAL function call: the first holds the ``ToolConfirmation``, the
    second the reconstructed ``types.FunctionCall``.
  """
  confirmation_by_original_id: dict[str, ToolConfirmation] = {}
  original_call_by_id: dict[str, types.FunctionCall] = {}

  for session_event in events:
    # An event without function calls contributes nothing.
    for request_call in session_event.get_function_calls() or ():
      if (
          request_call.id in confirmation_fc_ids
          and 'originalFunctionCall' in request_call.args
      ):
        restored_call = types.FunctionCall(
            **request_call.args['originalFunctionCall']
        )
        target_id = restored_call.id
        confirmation_by_original_id[target_id] = confirmations_by_fc_id[
            request_call.id
        ]
        original_call_by_id[target_id] = restored_call

  return confirmation_by_original_id, original_call_by_id
class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
"""Handles tool confirmation information to build the LLM request."""
@@ -53,14 +113,12 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
if not events:
return
request_confirmation_function_responses = (
dict()
) # {function call id, tool confirmation}
# Step 1: Find the last user-authored event and parse confirmation
# responses from it.
confirmations_by_fc_id: dict[str, ToolConfirmation] = {}
confirmation_event_index = -1
for k in range(len(events) - 1, -1, -1):
event = events[k]
# Find the first event authored by user
if not event.author or event.author != 'user':
continue
responses = event.get_function_responses()
@@ -70,101 +128,58 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor):
for function_response in responses:
if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME:
continue
# Find the FunctionResponse event that contains the user provided tool
# confirmation
if (
confirmations_by_fc_id[function_response.id] = _parse_tool_confirmation(
function_response.response
and len(function_response.response.values()) == 1
and 'response' in function_response.response.keys()
):
# ADK client must send a resuming run request with a function response
# that always encapsulate the confirmation result with a 'response'
# key
tool_confirmation = ToolConfirmation.model_validate(
json.loads(function_response.response['response'])
)
else:
tool_confirmation = ToolConfirmation.model_validate(
function_response.response
)
request_confirmation_function_responses[function_response.id] = (
tool_confirmation
)
confirmation_event_index = k
break
if not request_confirmation_function_responses:
if not confirmations_by_fc_id:
return
for i in range(len(events) - 2, -1, -1):
# Step 2: Resolve confirmation targets using extracted helper.
confirmation_fc_ids = set(confirmations_by_fc_id.keys())
tools_to_resume_with_confirmation, tools_to_resume_with_args = (
_resolve_confirmation_targets(
events, confirmation_fc_ids, confirmations_by_fc_id
)
)
if not tools_to_resume_with_confirmation:
return
# Step 3: Remove tools that have already been confirmed (dedup).
for i in range(len(events) - 1, confirmation_event_index, -1):
event = events[i]
# Find the system generated FunctionCall event requesting the tool
# confirmation
function_calls = event.get_function_calls()
if not function_calls:
fr_list = event.get_function_responses()
if not fr_list:
continue
tools_to_resume_with_confirmation = (
dict()
) # {Function call id, tool confirmation}
tools_to_resume_with_args = dict() # {Function call id, function calls}
for function_call in function_calls:
if (
function_call.id
not in request_confirmation_function_responses.keys()
):
continue
args = function_call.args
if 'originalFunctionCall' not in args:
continue
original_function_call = types.FunctionCall(
**args['originalFunctionCall']
)
tools_to_resume_with_confirmation[original_function_call.id] = (
request_confirmation_function_responses[function_call.id]
)
tools_to_resume_with_args[original_function_call.id] = (
original_function_call
)
for function_response in fr_list:
if function_response.id in tools_to_resume_with_confirmation:
tools_to_resume_with_confirmation.pop(function_response.id)
tools_to_resume_with_args.pop(function_response.id)
if not tools_to_resume_with_confirmation:
continue
break
# Remove the tools that have already been confirmed.
for i in range(len(events) - 1, confirmation_event_index, -1):
event = events[i]
function_response = event.get_function_responses()
if not function_response:
continue
for function_response in event.get_function_responses():
if function_response.id in tools_to_resume_with_confirmation:
tools_to_resume_with_confirmation.pop(function_response.id)
tools_to_resume_with_args.pop(function_response.id)
if not tools_to_resume_with_confirmation:
break
if not tools_to_resume_with_confirmation:
continue
if function_response_event := await functions.handle_function_call_list_async(
invocation_context,
tools_to_resume_with_args.values(),
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
# There could be parallel function calls that require input
# response would be a dict keyed by function call id
tools_to_resume_with_confirmation.keys(),
tools_to_resume_with_confirmation,
):
yield function_response_event
if not tools_to_resume_with_confirmation:
return
# Step 4: Re-execute the confirmed tools.
if function_response_event := await functions.handle_function_call_list_async(
invocation_context,
tools_to_resume_with_args.values(),
{
tool.name: tool
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
)
},
tools_to_resume_with_confirmation.keys(),
tools_to_resume_with_confirmation,
):
yield function_response_event
return
request_processor = _RequestConfirmationLlmRequestProcessor()
@@ -79,7 +79,9 @@ class TestAuthLlmRequestProcessor:
@pytest.fixture
def mock_auth_config(self):
"""Create a mock AuthConfig."""
return Mock(spec=AuthConfig)
config = Mock(spec=AuthConfig)
config.credential_key = None
return config
@pytest.fixture
def mock_function_response_with_auth(self, mock_auth_config):
@@ -347,10 +349,12 @@ class TestAuthLlmRequestProcessor:
auth_response_1,
auth_response_2,
]
user_event_with_multiple_responses.get_function_calls.return_value = []
# Create system function call events
system_function_call_1 = Mock()
system_function_call_1.id = 'auth_id_1'
system_function_call_1.name = REQUEST_EUC_FUNCTION_CALL_NAME
system_function_call_1.args = {
'function_call_id': 'tool_id_1',
'auth_config': mock_auth_config,
@@ -358,6 +362,7 @@ class TestAuthLlmRequestProcessor:
system_function_call_2 = Mock()
system_function_call_2.id = 'auth_id_2'
system_function_call_2.name = REQUEST_EUC_FUNCTION_CALL_NAME
system_function_call_2.args = {
'function_call_id': 'tool_id_2',
'auth_config': mock_auth_config,