From 706f9fe74db0197e19790ca542d372ce46d0ae87 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 16 Feb 2026 22:51:59 -0800 Subject: [PATCH] refactor: Extract reusable private methods Co-authored-by: Xiang (Sean) Zhou PiperOrigin-RevId: 871128393 --- .../adk/flows/llm_flows/base_llm_flow.py | 615 ++++++++++-------- src/google/adk/flows/llm_flows/single_flow.py | 66 +- tests/unittests/auth/test_toolset_auth.py | 24 +- .../flows/llm_flows/test_base_llm_flow.py | 13 +- 4 files changed, 384 insertions(+), 334 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index f1c1cce8..bea21793 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -113,6 +113,331 @@ def _finalize_model_response_event( return finalized_event +async def _resolve_toolset_auth( + invocation_context: InvocationContext, + agent: LlmAgent, +) -> AsyncGenerator[Event, None]: + """Resolves authentication for toolsets before tool listing. + + For each toolset with auth configured via get_auth_config(): + - If credential is available, populate auth_config.exchanged_auth_credential + - If credential is not available, yield auth request event and interrupt + + Args: + invocation_context: The invocation context. + agent: The LLM agent. + + Yields: + Auth request events if any toolset needs authentication. + """ + if not agent.tools: + return + + pending_auth_requests: dict[str, AuthConfig] = {} + callback_context = CallbackContext(invocation_context) + + for tool_union in agent.tools: + if not isinstance(tool_union, BaseToolset): + continue + + auth_config = tool_union.get_auth_config() + if not auth_config: + continue + + try: + credential = await CredentialManager(auth_config).get_auth_credential( + callback_context + ) + except ValueError as e: + # Validation errors from CredentialManager should be logged but not + # block the flow - the toolset may still work without auth + logger.warning( + 'Failed to get auth credential for toolset %s: %s', + type(tool_union).__name__, + e, + ) + credential = None + + if credential: + # Populate in-place for toolset to use in get_tools() + auth_config.exchanged_auth_credential = credential + else: + # Need auth - will interrupt + toolset_id = ( + f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}' + ) + pending_auth_requests[toolset_id] = auth_config + + if not pending_auth_requests: + return + + # Build auth requests dict with generated auth requests + auth_requests = { + credential_id: AuthHandler(auth_config).generate_auth_request() + for credential_id, auth_config in pending_auth_requests.items() + } + + # Yield event with auth requests using the shared helper + yield build_auth_request_event( + invocation_context, + auth_requests, + author=agent.name, + ) + + # Interrupt invocation + invocation_context.end_invocation = True + + +async def _handle_before_model_callback( + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, +) -> Optional[LlmResponse]: + """Runs before-model callbacks (plugins then agent callbacks). + + Args: + invocation_context: The invocation context. + llm_request: The LLM request being built. + model_response_event: The model response event for callback context. + + Returns: + An LlmResponse if a callback short-circuits the LLM call, else None. 
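+
+  Example:
+    A minimal sketch of a short-circuiting callback (hypothetical code,
+    for illustration only; not part of this change):
+
+      from google.adk.models.llm_response import LlmResponse
+      from google.genai import types
+
+      def skip_model(callback_context, llm_request):
+        # Returning a response makes this helper return it, and the
+        # flow then skips the real model call.
+        return LlmResponse(
+            content=types.Content(
+                role='model',
+                parts=[types.Part(text='Canned reply; model skipped.')],
+            )
+        )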
+ """ + agent = invocation_context.agent + + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_before_model_callback( + callback_context=callback_context, + llm_request=llm_request, + ) + ) + if callback_response: + return callback_response + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. + if not agent.canonical_before_model_callbacks: + return + for callback in agent.canonical_before_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_request=llm_request + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return callback_response + + +async def _handle_after_model_callback( + invocation_context: InvocationContext, + llm_response: LlmResponse, + model_response_event: Event, +) -> Optional[LlmResponse]: + """Runs after-model callbacks (plugins then agent callbacks). + + Also handles grounding metadata injection when google_search_agent is + among the agent's tools. + + Args: + invocation_context: The invocation context. + llm_response: The LLM response to process. + model_response_event: The model response event for callback context. + + Returns: + An altered LlmResponse if a callback modifies it, else None. + """ + agent = invocation_context.agent + + # Add grounding metadata to the response if needed. + # TODO(b/448114567): Remove this function once the workaround is no longer needed. + async def _maybe_add_grounding_metadata( + response: Optional[LlmResponse] = None, + ) -> Optional[LlmResponse]: + readonly_context = ReadonlyContext(invocation_context) + if (tools := invocation_context.canonical_tools_cache) is None: + tools = await agent.canonical_tools(readonly_context) + invocation_context.canonical_tools_cache = tools + + if not any(tool.name == 'google_search_agent' for tool in tools): + return response + ground_metadata = invocation_context.session.state.get( + 'temp:_adk_grounding_metadata', None + ) + if not ground_metadata: + return response + + if not response: + response = llm_response + response.grounding_metadata = ground_metadata + return response + + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + + # First run callbacks from the plugins. + callback_response = ( + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=CallbackContext(invocation_context), + llm_response=llm_response, + ) + ) + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + + # If no overrides are provided from the plugins, further run the canonical + # callbacks. 
+ if not agent.canonical_after_model_callbacks: + return await _maybe_add_grounding_metadata() + for callback in agent.canonical_after_model_callbacks: + callback_response = callback( + callback_context=callback_context, llm_response=llm_response + ) + if inspect.isawaitable(callback_response): + callback_response = await callback_response + if callback_response: + return await _maybe_add_grounding_metadata(callback_response) + return await _maybe_add_grounding_metadata() + + +async def _run_and_handle_error( + response_generator: AsyncGenerator[LlmResponse, None], + invocation_context: InvocationContext, + llm_request: LlmRequest, + model_response_event: Event, +) -> AsyncGenerator[LlmResponse, None]: + """Wraps an LLM response generator with error callback handling. + + Runs the response generator within a tracing span. If an error occurs, + runs on-model-error callbacks (plugins then agent callbacks). If a + callback returns a response, that response is yielded instead of + re-raising the error. + + Args: + response_generator: The async generator producing LLM responses. + invocation_context: The invocation context. + llm_request: The LLM request. + model_response_event: The model response event. + + Yields: + LlmResponse objects from the generator. + + Raises: + The original model error if no error callback handles it. + """ + agent = invocation_context.agent + if not hasattr(agent, 'canonical_on_model_error_callbacks'): + raise TypeError( + 'Expected agent to have canonical_on_model_error_callbacks' + f' attribute, but got {type(agent)}' + ) + + async def _run_on_model_error_callbacks( + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + error_response = ( + await invocation_context.plugin_manager.run_on_model_error_callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + ) + if error_response is not None: + return error_response + + for callback in agent.canonical_on_model_error_callbacks: + error_response = callback( + callback_context=callback_context, + llm_request=llm_request, + error=error, + ) + if inspect.isawaitable(error_response): + error_response = await error_response + if error_response is not None: + return error_response + + return None + + try: + async with Aclosing(response_generator) as agen: + with tracing.use_generate_content_span( + llm_request, invocation_context, model_response_event + ) as span: + async for llm_response in agen: + tracing.trace_generate_content_result(span, llm_response) + yield llm_response + except Exception as model_error: + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + error_response = await _run_on_model_error_callbacks( + callback_context=callback_context, + llm_request=llm_request, + error=model_error, + ) + if error_response is not None: + yield error_response + else: + raise model_error + + +async def _process_agent_tools( + invocation_context: InvocationContext, + llm_request: LlmRequest, +) -> None: + """Process the agent's tools and populate ``llm_request.tools_dict``. + + Iterates over the agent's ``tools`` list, converts each tool union + (callable, BaseTool, or BaseToolset) into resolved ``BaseTool`` + instances, and calls ``process_llm_request`` on each to register + tool declarations in the request. + + After this function returns, ``llm_request.tools_dict`` maps tool + names to ``BaseTool`` instances ready for function call dispatch. 
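+
+  For example (``roll_die`` is a hypothetical tool, shown for
+  illustration only), given an agent created with ``tools=[roll_die]``
+  where::
+
+      def roll_die(sides: int) -> int:
+        ...
+
+  this coroutine leaves ``llm_request.tools_dict`` mapping
+  ``'roll_die'`` to the ``FunctionTool`` that wraps it.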
+ + Args: + invocation_context: The invocation context (``agent`` is read + from ``invocation_context.agent``). + llm_request: The LLM request to populate with tool declarations. + """ + agent = invocation_context.agent + if not hasattr(agent, 'tools') or not agent.tools: + return + + multiple_tools = len(agent.tools) > 1 + model = agent.canonical_model + for tool_union in agent.tools: + tool_context = ToolContext(invocation_context) + + # If it's a toolset, process it first + if isinstance(tool_union, BaseToolset): + await tool_union.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + from ...agents.llm_agent import _convert_tool_union_to_tools + + # Then process all tools from this tool union + tools = await _convert_tool_union_to_tools( + tool_union, + ReadonlyContext(invocation_context), + model, + multiple_tools, + ) + for tool in tools: + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + class BaseLlmFlow(ABC): """A basic flow that calls the LLM in a loop until a final response is generated. @@ -538,7 +863,7 @@ class BaseLlmFlow(ABC): # Resolve toolset authentication before tool listing. # This ensures credentials are ready before get_tools() is called. async with Aclosing( - self._resolve_toolset_auth(invocation_context, agent) + _resolve_toolset_auth(invocation_context, agent) ) as agen: async for event in agen: yield event @@ -547,112 +872,7 @@ class BaseLlmFlow(ABC): return # Run processors for tools. - - # We may need to wrap some built-in tools if there are other tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if not agent.tools: - return - - multiple_tools = len(agent.tools) > 1 - model = agent.canonical_model - for tool_union in agent.tools: - tool_context = ToolContext(invocation_context) - - # If it's a toolset, process it first - if isinstance(tool_union, BaseToolset): - await tool_union.process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) - - from ...agents.llm_agent import _convert_tool_union_to_tools - - # Then process all tools from this tool union - tools = await _convert_tool_union_to_tools( - tool_union, - ReadonlyContext(invocation_context), - model, - multiple_tools, - ) - for tool in tools: - await tool.process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) - - async def _resolve_toolset_auth( - self, - invocation_context: InvocationContext, - agent: LlmAgent, - ) -> AsyncGenerator[Event, None]: - """Resolves authentication for toolsets before tool listing. - - For each toolset with auth configured via get_auth_config(): - - If credential is available, populate auth_config.exchanged_auth_credential - - If credential is not available, yield auth request event and interrupt - - Args: - invocation_context: The invocation context. - agent: The LLM agent. - - Yields: - Auth request events if any toolset needs authentication. 
- """ - if not agent.tools: - return - - pending_auth_requests: dict[str, AuthConfig] = {} - callback_context = CallbackContext(invocation_context) - - for tool_union in agent.tools: - if not isinstance(tool_union, BaseToolset): - continue - - auth_config = tool_union.get_auth_config() - if not auth_config: - continue - - try: - credential = await CredentialManager(auth_config).get_auth_credential( - callback_context - ) - except ValueError as e: - # Validation errors from CredentialManager should be logged but not - # block the flow - the toolset may still work without auth - logger.warning( - 'Failed to get auth credential for toolset %s: %s', - type(tool_union).__name__, - e, - ) - credential = None - - if credential: - # Populate in-place for toolset to use in get_tools() - auth_config.exchanged_auth_credential = credential - else: - # Need auth - will interrupt - toolset_id = ( - f'{TOOLSET_AUTH_CREDENTIAL_ID_PREFIX}{type(tool_union).__name__}' - ) - pending_auth_requests[toolset_id] = auth_config - - if not pending_auth_requests: - return - - # Build auth requests dict with generated auth requests - auth_requests = { - credential_id: AuthHandler(auth_config).generate_auth_request() - for credential_id, auth_config in pending_auth_requests.items() - } - - # Yield event with auth requests using the shared helper - yield build_auth_request_event( - invocation_context, - auth_requests, - author=agent.name, - ) - - # Interrupt invocation - invocation_context.end_invocation = True + await _process_agent_tools(invocation_context, llm_request) async def _postprocess_async( self, @@ -881,7 +1101,7 @@ class BaseLlmFlow(ABC): model_response_event: Event, ) -> AsyncGenerator[LlmResponse, None]: # Runs before_model_callback if it exists. - if response := await self._handle_before_model_callback( + if response := await _handle_before_model_callback( invocation_context, llm_request, model_response_event ): yield response @@ -906,7 +1126,7 @@ class BaseLlmFlow(ABC): invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) async with Aclosing( - self._run_and_handle_error( + _run_and_handle_error( responses_generator, invocation_context, llm_request, @@ -915,7 +1135,7 @@ class BaseLlmFlow(ABC): ) as agen: async for llm_response in agen: # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( + if altered_llm_response := await _handle_after_model_callback( invocation_context, llm_response, model_response_event ): llm_response = altered_llm_response @@ -939,7 +1159,7 @@ class BaseLlmFlow(ABC): == StreamingMode.SSE, ) async with Aclosing( - self._run_and_handle_error( + _run_and_handle_error( responses_generator, invocation_context, llm_request, @@ -955,7 +1175,7 @@ class BaseLlmFlow(ABC): span, ) # Runs after_model_callback if it exists. 
- if altered_llm_response := await self._handle_after_model_callback( + if altered_llm_response := await _handle_after_model_callback( invocation_context, llm_response, model_response_event ): llm_response = altered_llm_response @@ -966,100 +1186,6 @@ class BaseLlmFlow(ABC): async for event in agen: yield event - async def _handle_before_model_callback( - self, - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> Optional[LlmResponse]: - agent = invocation_context.agent - - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_before_model_callback( - callback_context=callback_context, - llm_request=llm_request, - ) - ) - if callback_response: - return callback_response - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. - if not agent.canonical_before_model_callbacks: - return - for callback in agent.canonical_before_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_request=llm_request - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return callback_response - - async def _handle_after_model_callback( - self, - invocation_context: InvocationContext, - llm_response: LlmResponse, - model_response_event: Event, - ) -> Optional[LlmResponse]: - agent = invocation_context.agent - - # Add grounding metadata to the response if needed. - # TODO(b/448114567): Remove this function once the workaround is no longer needed. - async def _maybe_add_grounding_metadata( - response: Optional[LlmResponse] = None, - ) -> Optional[LlmResponse]: - readonly_context = ReadonlyContext(invocation_context) - if (tools := invocation_context.canonical_tools_cache) is None: - tools = await agent.canonical_tools(readonly_context) - invocation_context.canonical_tools_cache = tools - - if not any(tool.name == 'google_search_agent' for tool in tools): - return response - ground_metadata = invocation_context.session.state.get( - 'temp:_adk_grounding_metadata', None - ) - if not ground_metadata: - return response - - if not response: - response = llm_response - response.grounding_metadata = ground_metadata - return response - - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - - # First run callbacks from the plugins. - callback_response = ( - await invocation_context.plugin_manager.run_after_model_callback( - callback_context=CallbackContext(invocation_context), - llm_response=llm_response, - ) - ) - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - - # If no overrides are provided from the plugins, further run the canonical - # callbacks. 
- if not agent.canonical_after_model_callbacks: - return await _maybe_add_grounding_metadata() - for callback in agent.canonical_after_model_callbacks: - callback_response = callback( - callback_context=callback_context, llm_response=llm_response - ) - if inspect.isawaitable(callback_response): - callback_response = await callback_response - if callback_response: - return await _maybe_add_grounding_metadata(callback_response) - return await _maybe_add_grounding_metadata() - def _finalize_model_response_event( self, llm_request: LlmRequest, @@ -1106,83 +1232,6 @@ class BaseLlmFlow(ABC): # model audio here (flush_user_audio=False, flush_model_audio=True). return [] - async def _run_and_handle_error( - self, - response_generator: AsyncGenerator[LlmResponse, None], - invocation_context: InvocationContext, - llm_request: LlmRequest, - model_response_event: Event, - ) -> AsyncGenerator[LlmResponse, None]: - """Runs the response generator and processes the error with plugins. - - Args: - response_generator: The response generator to run. - invocation_context: The invocation context. - llm_request: The LLM request. - model_response_event: The model response event. - - Yields: - A generator of LlmResponse. - """ - - agent = invocation_context.agent - if not hasattr(agent, 'canonical_on_model_error_callbacks'): - raise TypeError( - 'Expected agent to have canonical_on_model_error_callbacks' - f' attribute, but got {type(agent)}' - ) - - async def _run_on_model_error_callbacks( - *, - callback_context: CallbackContext, - llm_request: LlmRequest, - error: Exception, - ) -> Optional[LlmResponse]: - error_response = ( - await invocation_context.plugin_manager.run_on_model_error_callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - ) - if error_response is not None: - return error_response - - for callback in agent.canonical_on_model_error_callbacks: - error_response = callback( - callback_context=callback_context, - llm_request=llm_request, - error=error, - ) - if inspect.isawaitable(error_response): - error_response = await error_response - if error_response is not None: - return error_response - - return None - - try: - async with Aclosing(response_generator) as agen: - with tracing.use_generate_content_span( - llm_request, invocation_context, model_response_event - ) as span: - async for llm_response in agen: - tracing.trace_generate_content_result(span, llm_response) - yield llm_response - except Exception as model_error: - callback_context = CallbackContext( - invocation_context, event_actions=model_response_event.actions - ) - error_response = await _run_on_model_error_callbacks( - callback_context=callback_context, - llm_request=llm_request, - error=model_error, - ) - if error_response is not None: - yield error_response - else: - raise model_error - def __get_llm(self, invocation_context: InvocationContext) -> BaseLlm: agent = invocation_context.agent if not hasattr(agent, 'canonical_model'): diff --git a/src/google/adk/flows/llm_flows/single_flow.py b/src/google/adk/flows/llm_flows/single_flow.py index d79ec7c4..0a26cdce 100644 --- a/src/google/adk/flows/llm_flows/single_flow.py +++ b/src/google/adk/flows/llm_flows/single_flow.py @@ -34,6 +34,43 @@ from .base_llm_flow import BaseLlmFlow logger = logging.getLogger('google_adk.' 
+ __name__) +def _create_request_processors(): + """Create the standard request processor list for a single-agent flow.""" + return [ + basic.request_processor, + auth_preprocessor.request_processor, + request_confirmation.request_processor, + instructions.request_processor, + identity.request_processor, + contents.request_processor, + # Context cache processor sets up cache config and finds + # existing cache metadata. + context_cache_processor.request_processor, + # Interactions processor extracts previous_interaction_id for + # stateful conversations via the Interactions API. + interactions_processor.request_processor, + # Some implementations of NL Planning mark planning contents + # as thoughts in the post processor. Since these need to be + # unmarked, NL Planning should be after contents. + _nl_planning.request_processor, + # Code execution should be after the contents as it mutates + # the contents to optimize data files. + _code_execution.request_processor, + # Output schema processor adds system instruction and + # set_model_response when both output_schema and tools are + # present. + _output_schema_processor.request_processor, + ] + + +def _create_response_processors(): + """Create the standard response processor list for a single-agent flow.""" + return [ + _nl_planning.response_processor, + _code_execution.response_processor, + ] + + class SingleFlow(BaseLlmFlow): """SingleFlow is the LLM flows that handles tools calls. @@ -43,30 +80,5 @@ class SingleFlow(BaseLlmFlow): def __init__(self): super().__init__() - self.request_processors += [ - basic.request_processor, - auth_preprocessor.request_processor, - request_confirmation.request_processor, - instructions.request_processor, - identity.request_processor, - contents.request_processor, - # Context cache processor sets up cache config and finds existing cache metadata - context_cache_processor.request_processor, - # Interactions processor extracts previous_interaction_id for stateful - # conversations via the Interactions API - interactions_processor.request_processor, - # Some implementations of NL Planning mark planning contents as thoughts - # in the post processor. Since these need to be unmarked, NL Planning - # should be after contents. - _nl_planning.request_processor, - # Code execution should be after the contents as it mutates the contents - # to optimize data files. - _code_execution.request_processor, - # Output schema processor add system instruction and set_model_response - # when both output_schema and tools are present. 
- _output_schema_processor.request_processor, - ] - self.response_processors += [ - _nl_planning.response_processor, - _code_execution.response_processor, - ] + self.request_processors += _create_request_processors() + self.response_processors += _create_response_processors() diff --git a/tests/unittests/auth/test_toolset_auth.py b/tests/unittests/auth/test_toolset_auth.py index bd4d8f2a..b5efc425 100644 --- a/tests/unittests/auth/test_toolset_auth.py +++ b/tests/unittests/auth/test_toolset_auth.py @@ -31,6 +31,7 @@ from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_preprocessor import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.auth_tool import AuthToolArguments +from google.adk.flows.llm_flows.base_llm_flow import _resolve_toolset_auth from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.flows.llm_flows.base_llm_flow import TOOLSET_AUTH_CREDENTIAL_ID_PREFIX as FLOW_PREFIX from google.adk.flows.llm_flows.functions import build_auth_request_event @@ -119,14 +120,6 @@ class TestResolveToolsetAuth: agent.tools = [] return agent - @pytest.fixture - def flow(self): - """Create a BaseLlmFlow instance for testing.""" - # BaseLlmFlow is abstract, but we can still test _resolve_toolset_auth - flow = Mock(spec=BaseLlmFlow) - flow._resolve_toolset_auth = BaseLlmFlow._resolve_toolset_auth - return flow - @pytest.mark.asyncio async def test_no_tools_returns_no_events( self, mock_invocation_context, mock_agent @@ -134,9 +127,8 @@ class TestResolveToolsetAuth: """Test that no events are yielded when agent has no tools.""" mock_agent.tools = [] - flow = BaseLlmFlow.__new__(BaseLlmFlow) events = [] - async for event in flow._resolve_toolset_auth( + async for event in _resolve_toolset_auth( mock_invocation_context, mock_agent ): events.append(event) @@ -152,9 +144,8 @@ class TestResolveToolsetAuth: toolset = MockToolset(auth_config=None) mock_agent.tools = [toolset] - flow = BaseLlmFlow.__new__(BaseLlmFlow) events = [] - async for event in flow._resolve_toolset_auth( + async for event in _resolve_toolset_auth( mock_invocation_context, mock_agent ): events.append(event) @@ -184,9 +175,8 @@ class TestResolveToolsetAuth: mock_manager.get_auth_credential = AsyncMock(return_value=mock_credential) MockCredentialManager.return_value = mock_manager - flow = BaseLlmFlow.__new__(BaseLlmFlow) events = [] - async for event in flow._resolve_toolset_auth( + async for event in _resolve_toolset_auth( mock_invocation_context, mock_agent ): events.append(event) @@ -213,9 +203,8 @@ class TestResolveToolsetAuth: mock_manager.get_auth_credential = AsyncMock(return_value=None) MockCredentialManager.return_value = mock_manager - flow = BaseLlmFlow.__new__(BaseLlmFlow) events = [] - async for event in flow._resolve_toolset_auth( + async for event in _resolve_toolset_auth( mock_invocation_context, mock_agent ): events.append(event) @@ -258,9 +247,8 @@ class TestResolveToolsetAuth: mock_manager.get_auth_credential = AsyncMock(return_value=None) MockCredentialManager.return_value = mock_manager - flow = BaseLlmFlow.__new__(BaseLlmFlow) events = [] - async for event in flow._resolve_toolset_auth( + async for event in _resolve_toolset_auth( mock_invocation_context, mock_agent ): events.append(event) diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index 2f4c1534..3dfadbca 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ 
b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -19,6 +19,7 @@ from unittest.mock import AsyncMock from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import _handle_after_model_callback from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest @@ -285,7 +286,7 @@ async def test_handle_after_model_callback_grounding_with_no_callbacks( ) flow = BaseLlmFlowForTesting() - result = await flow._handle_after_model_callback( + result = await _handle_after_model_callback( invocation_context, llm_response, event ) @@ -342,7 +343,7 @@ async def test_handle_after_model_callback_grounding_with_callback_override( ) flow = BaseLlmFlowForTesting() - result = await flow._handle_after_model_callback( + result = await _handle_after_model_callback( invocation_context, llm_response, event ) @@ -404,7 +405,7 @@ async def test_handle_after_model_callback_grounding_with_plugin_override( ) flow = BaseLlmFlowForTesting() - result = await flow._handle_after_model_callback( + result = await _handle_after_model_callback( invocation_context, llm_response, event ) @@ -461,13 +462,13 @@ async def test_handle_after_model_callback_caches_canonical_tools(): flow = BaseLlmFlowForTesting() # Call _handle_after_model_callback multiple times with the same context - result1 = await flow._handle_after_model_callback( + result1 = await _handle_after_model_callback( invocation_context, llm_response, event ) - result2 = await flow._handle_after_model_callback( + result2 = await _handle_after_model_callback( invocation_context, llm_response, event ) - result3 = await flow._handle_after_model_callback( + result3 = await _handle_after_model_callback( invocation_context, llm_response, event )
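--
The extracted helpers are plain module-level functions, so they can be
smoke-tested without instantiating the abstract BaseLlmFlow. A minimal
sketch (illustrative only; the assertions are not part of this patch):

    from google.adk.flows.llm_flows.single_flow import (
        _create_request_processors,
        _create_response_processors,
    )

    # Each call builds a fresh list, so callers may extend the result
    # without mutating a shared default.
    assert len(_create_request_processors()) == 11  # as listed in this patch
    assert len(_create_response_processors()) == 2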