From c59afc21cbed27d1328872cdc2b0e182ab2ca6c8 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 2 Mar 2026 14:27:07 -0800 Subject: [PATCH] refactor: extract reusable functions from hitl and auth preprocessor Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 877578253 --- src/google/adk/auth/auth_preprocessor.py | 215 +++++++++++------- .../flows/llm_flows/request_confirmation.py | 189 ++++++++------- .../unittests/auth/test_auth_preprocessor.py | 7 +- 3 files changed, 235 insertions(+), 176 deletions(-) diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index 37ad6745..76dd2dda 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator from typing_extensions import override @@ -25,6 +26,7 @@ from ..flows.llm_flows import functions from ..flows.llm_flows._base_llm_processor import BaseLlmRequestProcessor from ..flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME from ..models.llm_request import LlmRequest +from ..sessions.state import State from .auth_handler import AuthHandler from .auth_tool import AuthConfig from .auth_tool import AuthToolArguments @@ -35,6 +37,93 @@ from .auth_tool import AuthToolArguments TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' +async def _store_auth_and_collect_resume_targets( + events: list[Event], + auth_fc_ids: set[str], + auth_responses: dict[str, Any], + state: State, +) -> set[str]: + """Store auth credentials and return original function call IDs to resume. + + Scans session events for ``adk_request_credential`` function calls whose + IDs are in *auth_fc_ids*, extracts ``credential_key`` from their + ``AuthToolArguments`` args, merges ``credential_key`` into the + corresponding auth response, stores credentials via ``AuthHandler``, + and returns the set of original function call IDs that should be + re-executed (excluding toolset auth). + + Args: + events: Session events to scan. + auth_fc_ids: IDs of ``adk_request_credential`` function calls to match. + auth_responses: Mapping of FC ID -> auth config response dict from the + client. + state: Session state for temporary credential storage. + + Returns: + Set of original function call IDs to resume. + """ + # Step 1: Scan events for matching adk_request_credential function calls + # to extract AuthToolArguments (contains credential_key). + requested_auth_config_by_id: dict[str, AuthConfig] = {} + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + try: + for function_call in event_function_calls: + if ( + function_call.id in auth_fc_ids + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + requested_auth_config_by_id[function_call.id] = args.auth_config + except TypeError: + continue + + # Step 2: Store credentials. Merge credential_key from the original + # request into the client's auth response before storing. + for fc_id in auth_fc_ids: + if fc_id not in auth_responses: + continue + auth_config = AuthConfig.model_validate(auth_responses[fc_id]) + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if ( + requested_auth_config + and requested_auth_config.credential_key is not None + ): + auth_config.credential_key = requested_auth_config.credential_key + await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( + state=state + ) + + # Step 3: Collect original function call IDs to resume, skipping + # toolset auth entries which don't map to a resumable function call. + tools_to_resume: set[str] = set() + for fc_id in auth_fc_ids: + requested_auth_config = requested_auth_config_by_id.get(fc_id) + if not requested_auth_config: + continue + # Re-parse to get function_call_id (AuthConfig doesn't carry it; + # AuthToolArguments does). + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + for function_call in event_function_calls: + if ( + function_call.id == fc_id + and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME + ): + args = AuthToolArguments.model_validate(function_call.args) + if args.function_call_id.startswith( + TOOLSET_AUTH_CREDENTIAL_ID_PREFIX + ): + continue + tools_to_resume.add(args.function_call_id) + + return tools_to_resume + + class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): """Handles auth information to build the LLM request.""" @@ -49,8 +138,8 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_euc_function_call_ids = set() - # find the last event with non-None content + # Find the last user-authored event with function responses to + # identify adk_request_credential responses. last_event_with_content = None for i in range(len(events) - 1, -1, -1): event = events[i] @@ -58,7 +147,6 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): last_event_with_content = event break - # check if the last event with content is authored by user if not last_event_with_content or last_event_with_content.author != 'user': return @@ -66,104 +154,55 @@ class _AuthLlmRequestProcessor(BaseLlmRequestProcessor): if not responses: return - requested_auth_config_by_request_id = {} - # look for auth response + # Collect adk_request_credential function response IDs and their + # response dicts. + auth_fc_ids: set[str] = set() + auth_responses: dict[str, Any] = {} for function_call_response in responses: if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: continue - # found the function call response for the system long running request euc - # function call - request_euc_function_call_ids.add(function_call_response.id) - - if request_euc_function_call_ids: - for event in events: - function_calls = event.get_function_calls() - if not function_calls: - continue - try: - for function_call in function_calls: - if ( - function_call.id in request_euc_function_call_ids - and function_call.name == REQUEST_EUC_FUNCTION_CALL_NAME - ): - args = AuthToolArguments.model_validate(function_call.args) - requested_auth_config_by_request_id[function_call.id] = ( - args.auth_config - ) - except TypeError: - continue - - for function_call_response in responses: - if function_call_response.name != REQUEST_EUC_FUNCTION_CALL_NAME: - continue - - auth_config = AuthConfig.model_validate(function_call_response.response) - requested_auth_config = requested_auth_config_by_request_id.get( - function_call_response.id - ) - if ( - requested_auth_config - and requested_auth_config.credential_key is not None - ): - auth_config.credential_key = requested_auth_config.credential_key - await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( - state=invocation_context.session.state + auth_fc_ids.add(function_call_response.id) + auth_responses[function_call_response.id] = ( + function_call_response.response ) - if not request_euc_function_call_ids: + if not auth_fc_ids: return + # Store credentials and collect tools to resume. + tools_to_resume = await _store_auth_and_collect_resume_targets( + events, auth_fc_ids, auth_responses, invocation_context.session.state + ) + + if not tools_to_resume: + return + + # Find the original function call event and re-execute the tools + # that needed auth. for i in range(len(events) - 2, -1, -1): event = events[i] - # looking for the system long running request euc function call function_calls = event.get_function_calls() if not function_calls: continue - tools_to_resume = set() - - for function_call in function_calls: - if function_call.id not in request_euc_function_call_ids: - continue - args = AuthToolArguments.model_validate(function_call.args) - - # Skip toolset auth - auth response is already stored in session state - # and we don't need to resume a function call for toolsets - if args.function_call_id.startswith(TOOLSET_AUTH_CREDENTIAL_ID_PREFIX): - continue - - tools_to_resume.add(args.function_call_id) - if not tools_to_resume: - continue - - # found the system long running request euc function call - # looking for original function call that requests euc - for j in range(i - 1, -1, -1): - event = events[j] - function_calls = event.get_function_calls() - if not function_calls: - continue - - if any([ - function_call.id in tools_to_resume - for function_call in function_calls - ]): - if function_response_event := await functions.handle_function_calls_async( - invocation_context, - event, - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # there could be parallel function calls that require auth - # auth response would be a dict keyed by function call id - tools_to_resume, - ): - yield function_response_event - return - return + if any([ + function_call.id in tools_to_resume + for function_call in function_calls + ]): + if function_response_event := await functions.handle_function_calls_async( + invocation_context, + event, + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume, + ): + yield function_response_event + return + return request_processor = _AuthLlmRequestProcessor() diff --git a/src/google/adk/flows/llm_flows/request_confirmation.py b/src/google/adk/flows/llm_flows/request_confirmation.py index f7b7f7f6..d066db79 100644 --- a/src/google/adk/flows/llm_flows/request_confirmation.py +++ b/src/google/adk/flows/llm_flows/request_confirmation.py @@ -15,6 +15,7 @@ from __future__ import annotations import json import logging +from typing import Any from typing import AsyncGenerator from typing import TYPE_CHECKING @@ -37,6 +38,65 @@ if TYPE_CHECKING: logger = logging.getLogger('google_adk.' + __name__) +def _parse_tool_confirmation(response: dict[str, Any]) -> ToolConfirmation: + """Parse ToolConfirmation from a function response dict. + + Handles both the direct dict format and the ADK client's + ``{'response': json_string}`` wrapper format. + + """ + if response and len(response.values()) == 1 and 'response' in response.keys(): + return ToolConfirmation.model_validate(json.loads(response['response'])) + return ToolConfirmation.model_validate(response) + + +def _resolve_confirmation_targets( + events: list[Event], + confirmation_fc_ids: set[str], + confirmations_by_fc_id: dict[str, ToolConfirmation], +) -> tuple[dict[str, ToolConfirmation], dict[str, types.FunctionCall]]: + """Find original function calls for confirmed tools. + + Scans events for ``adk_request_confirmation`` function calls whose IDs + are in *confirmation_fc_ids*, extracts the ``originalFunctionCall`` from + their args, and maps each confirmation to the original FC ID. + + Args: + events: Session events to scan. + confirmation_fc_ids: IDs of ``adk_request_confirmation`` function calls. + confirmations_by_fc_id: Mapping of confirmation FC ID -> + ``ToolConfirmation``. + + Returns: + Tuple of ``(tool_confirmation_dict, original_fcs_dict)`` where both + are keyed by the ORIGINAL function call IDs. + """ + tool_confirmation_dict: dict[str, ToolConfirmation] = {} + original_fcs_dict: dict[str, types.FunctionCall] = {} + + for event in events: + event_function_calls = event.get_function_calls() + if not event_function_calls: + continue + + for function_call in event_function_calls: + if function_call.id not in confirmation_fc_ids: + continue + + args = function_call.args + if 'originalFunctionCall' not in args: + continue + original_function_call = types.FunctionCall( + **args['originalFunctionCall'] + ) + tool_confirmation_dict[original_function_call.id] = ( + confirmations_by_fc_id[function_call.id] + ) + original_fcs_dict[original_function_call.id] = original_function_call + + return tool_confirmation_dict, original_fcs_dict + + class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): """Handles tool confirmation information to build the LLM request.""" @@ -53,14 +113,12 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): if not events: return - request_confirmation_function_responses = ( - dict() - ) # {function call id, tool confirmation} - + # Step 1: Find the last user-authored event and parse confirmation + # responses from it. + confirmations_by_fc_id: dict[str, ToolConfirmation] = {} confirmation_event_index = -1 for k in range(len(events) - 1, -1, -1): event = events[k] - # Find the first event authored by user if not event.author or event.author != 'user': continue responses = event.get_function_responses() @@ -70,101 +128,58 @@ class _RequestConfirmationLlmRequestProcessor(BaseLlmRequestProcessor): for function_response in responses: if function_response.name != REQUEST_CONFIRMATION_FUNCTION_CALL_NAME: continue - - # Find the FunctionResponse event that contains the user provided tool - # confirmation - if ( + confirmations_by_fc_id[function_response.id] = _parse_tool_confirmation( function_response.response - and len(function_response.response.values()) == 1 - and 'response' in function_response.response.keys() - ): - # ADK client must send a resuming run request with a function response - # that always encapsulate the confirmation result with a 'response' - # key - tool_confirmation = ToolConfirmation.model_validate( - json.loads(function_response.response['response']) - ) - else: - tool_confirmation = ToolConfirmation.model_validate( - function_response.response - ) - request_confirmation_function_responses[function_response.id] = ( - tool_confirmation ) confirmation_event_index = k break - if not request_confirmation_function_responses: + if not confirmations_by_fc_id: return - for i in range(len(events) - 2, -1, -1): + # Step 2: Resolve confirmation targets using extracted helper. + confirmation_fc_ids = set(confirmations_by_fc_id.keys()) + tools_to_resume_with_confirmation, tools_to_resume_with_args = ( + _resolve_confirmation_targets( + events, confirmation_fc_ids, confirmations_by_fc_id + ) + ) + + if not tools_to_resume_with_confirmation: + return + + # Step 3: Remove tools that have already been confirmed (dedup). + for i in range(len(events) - 1, confirmation_event_index, -1): event = events[i] - # Find the system generated FunctionCall event requesting the tool - # confirmation - function_calls = event.get_function_calls() - if not function_calls: + fr_list = event.get_function_responses() + if not fr_list: continue - tools_to_resume_with_confirmation = ( - dict() - ) # {Function call id, tool confirmation} - tools_to_resume_with_args = dict() # {Function call id, function calls} - - for function_call in function_calls: - if ( - function_call.id - not in request_confirmation_function_responses.keys() - ): - continue - - args = function_call.args - if 'originalFunctionCall' not in args: - continue - original_function_call = types.FunctionCall( - **args['originalFunctionCall'] - ) - tools_to_resume_with_confirmation[original_function_call.id] = ( - request_confirmation_function_responses[function_call.id] - ) - tools_to_resume_with_args[original_function_call.id] = ( - original_function_call - ) + for function_response in fr_list: + if function_response.id in tools_to_resume_with_confirmation: + tools_to_resume_with_confirmation.pop(function_response.id) + tools_to_resume_with_args.pop(function_response.id) if not tools_to_resume_with_confirmation: - continue + break - # Remove the tools that have already been confirmed. - for i in range(len(events) - 1, confirmation_event_index, -1): - event = events[i] - function_response = event.get_function_responses() - if not function_response: - continue - - for function_response in event.get_function_responses(): - if function_response.id in tools_to_resume_with_confirmation: - tools_to_resume_with_confirmation.pop(function_response.id) - tools_to_resume_with_args.pop(function_response.id) - if not tools_to_resume_with_confirmation: - break - - if not tools_to_resume_with_confirmation: - continue - - if function_response_event := await functions.handle_function_call_list_async( - invocation_context, - tools_to_resume_with_args.values(), - { - tool.name: tool - for tool in await agent.canonical_tools( - ReadonlyContext(invocation_context) - ) - }, - # There could be parallel function calls that require input - # response would be a dict keyed by function call id - tools_to_resume_with_confirmation.keys(), - tools_to_resume_with_confirmation, - ): - yield function_response_event + if not tools_to_resume_with_confirmation: return + # Step 4: Re-execute the confirmed tools. + if function_response_event := await functions.handle_function_call_list_async( + invocation_context, + tools_to_resume_with_args.values(), + { + tool.name: tool + for tool in await agent.canonical_tools( + ReadonlyContext(invocation_context) + ) + }, + tools_to_resume_with_confirmation.keys(), + tools_to_resume_with_confirmation, + ): + yield function_response_event + return + request_processor = _RequestConfirmationLlmRequestProcessor() diff --git a/tests/unittests/auth/test_auth_preprocessor.py b/tests/unittests/auth/test_auth_preprocessor.py index 04a64fc5..fb45cc34 100644 --- a/tests/unittests/auth/test_auth_preprocessor.py +++ b/tests/unittests/auth/test_auth_preprocessor.py @@ -79,7 +79,9 @@ class TestAuthLlmRequestProcessor: @pytest.fixture def mock_auth_config(self): """Create a mock AuthConfig.""" - return Mock(spec=AuthConfig) + config = Mock(spec=AuthConfig) + config.credential_key = None + return config @pytest.fixture def mock_function_response_with_auth(self, mock_auth_config): @@ -347,10 +349,12 @@ class TestAuthLlmRequestProcessor: auth_response_1, auth_response_2, ] + user_event_with_multiple_responses.get_function_calls.return_value = [] # Create system function call events system_function_call_1 = Mock() system_function_call_1.id = 'auth_id_1' + system_function_call_1.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_1.args = { 'function_call_id': 'tool_id_1', 'auth_config': mock_auth_config, @@ -358,6 +362,7 @@ class TestAuthLlmRequestProcessor: system_function_call_2 = Mock() system_function_call_2.id = 'auth_id_2' + system_function_call_2.name = REQUEST_EUC_FUNCTION_CALL_NAME system_function_call_2.args = { 'function_call_id': 'tool_id_2', 'auth_config': mock_auth_config,